---
title: 05c Multiscale Flow Embeddings with a Grid
keywords: fastai
sidebar: home_sidebar
nb_path: "05c Multiscale Diffusion Flow Embedding.ipynb"
---
compute_grid[source]
compute_grid(X,grid_width=20)
Returns a grid of points that bounds the points in X. The grid has 'grid_width' points along each axis. Accepts X, a tensor of shape n x 2. Returns a tensor of shape grid_width^2 x 2.
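As a rough illustration of what such a bounding grid looks like, the sketch below builds one with NumPy. It is not the library's implementation: the name `bounding_grid_sketch` is hypothetical, NumPy arrays stand in for the torch tensors used elsewhere on this page, and it assumes the grid spans exactly the minimum and maximum of X along each axis with no padding, which the docstring does not specify.

```python
import numpy as np

def bounding_grid_sketch(X, grid_width=20):
    """Hypothetical stand-in for compute_grid: an evenly spaced 2D grid over X's bounding box."""
    X = np.asarray(X, dtype=float)
    # Assumption: the grid spans exactly [min, max] of each coordinate (no margin).
    xs = np.linspace(X[:, 0].min(), X[:, 0].max(), grid_width)
    ys = np.linspace(X[:, 1].min(), X[:, 1].max(), grid_width)
    # Cartesian product of the two axes, flattened to (grid_width**2, 2).
    gx, gy = np.meshgrid(xs, ys, indexing="ij")
    return np.stack([gx.ravel(), gy.ravel()], axis=1)

X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
grid = bounding_grid_sketch(X, grid_width=20)
print(grid.shape)  # (400, 2)
```

With `grid_width=20` the result has 400 rows, matching the grid_width^2 x 2 shape stated in the docstring.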
A cornerstone of this method is powering the diffusion matrix so that random walks can jump between points in the dataset and points in the surrounding grid. Ordinary matrix powering suffices, provided we add the grid points to the dataset before powering and remove them afterwards.
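The add-then-remove idea can be sketched in a few lines. This is a simplified illustration under stated assumptions, not the library's code: `powered_diffusion_with_grid_sketch` is a hypothetical name, it uses a plain Gaussian kernel in NumPy and leaves out the flow function and flow strength entirely, and the final row renormalization is an assumption.

```python
import numpy as np

def powered_diffusion_with_grid_sketch(X, grid, t=1, sigma=0.5):
    """Hypothetical sketch: power a diffusion matrix over data + grid, keep the data block."""
    n = len(X)
    points = np.concatenate([X, grid], axis=0)  # add grid points to the dataset
    # Plain Gaussian affinities; the flow_function / flow_strength terms are omitted here.
    sq_dists = ((points[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)
    K = np.exp(-sq_dists / (2 * sigma ** 2))
    P = K / K.sum(axis=1, keepdims=True)        # row-stochastic transition matrix
    P_t = np.linalg.matrix_power(P, t)          # t-step walks may pass through grid points
    P_data = P_t[:n, :n]                        # take the grid points back out
    # Assumption: renormalize so each row is again a distribution over data points.
    return P_data / P_data.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
X = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
grid = rng.random((5, 2))
P = powered_diffusion_with_grid_sketch(X, grid, t=2, sigma=0.5)
print(P.shape)        # (3, 3)
print(P.sum(axis=1))  # each row sums to 1, up to floating point error
```

Powering the full matrix before slicing is what lets a t-step walk between two data points route through grid points; slicing first and powering afterwards would discard those paths.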
diffusion_matrix_with_grid_points[source]
diffusion_matrix_with_grid_points(X,grid,flow_function,t,sigma,flow_strength)
A = torch.tensor([[1,2],[3,4],[5,6]])
B = torch.rand(5,2)
K = GaussianVectorField(2,25,device=torch.device('cpu'))
out = diffusion_matrix_with_grid_points(A,B,K,1, 0.5, 4)
out
tensor([[9.9995e-01, 4.7606e-05, 1.6631e-07],
[2.2208e-03, 9.9237e-01, 5.4120e-03],
[1.2162e-05, 3.4813e-03, 9.9651e-01]], grad_fn=<DivBackward0>)
The architecture should support modular swapping of point embedding methods. For example, here is the simple feedforward net we use in the autoencoder:
class FeedForwardReLU[source]
FeedForwardReLU(shape) ::Module
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes::
    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:to, etc.
.. note::
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
fa = FeedForwardReLU(shape=[2,3,4,5,6,5,4,3,2])
fa
FeedForwardReLU(
  (FA): Sequential(
    (0): Linear(in_features=2, out_features=3, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=3, out_features=4, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=4, out_features=5, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=5, out_features=6, bias=True)
    (7): LeakyReLU(negative_slope=0.01)
    (8): Linear(in_features=6, out_features=5, bias=True)
    (9): LeakyReLU(negative_slope=0.01)
    (10): Linear(in_features=5, out_features=4, bias=True)
    (11): LeakyReLU(negative_slope=0.01)
    (12): Linear(in_features=4, out_features=3, bias=True)
    (13): LeakyReLU(negative_slope=0.01)
    (14): Linear(in_features=3, out_features=2, bias=True)
  )
)
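The printed module suggests how `shape` maps to layers: one `Linear` per consecutive pair of sizes, with a `LeakyReLU` (despite the class name) between them and none after the last layer. The sketch below reproduces that layout; it is a reconstruction from the printed output only, and `FeedForwardSketch` is a hypothetical name, not the library's class.

```python
import torch
import torch.nn as nn

class FeedForwardSketch(nn.Module):
    """Hypothetical reconstruction: Linear layers sized by `shape`, LeakyReLU between them."""
    def __init__(self, shape):
        super().__init__()
        layers = []
        for i, (d_in, d_out) in enumerate(zip(shape[:-1], shape[1:])):
            layers.append(nn.Linear(d_in, d_out))
            # No activation after the final Linear, matching the printed module.
            if i < len(shape) - 2:
                layers.append(nn.LeakyReLU())
        self.FA = nn.Sequential(*layers)

    def forward(self, x):
        return self.FA(x)

net = FeedForwardSketch(shape=[2, 3, 4, 5, 6, 5, 4, 3, 2])
out = net(torch.rand(7, 2))
print(len(net.FA), out.shape)
```

Because the whole network is defined by a list of layer widths, an embedder or decoder of a different depth can be swapped in by changing only that list.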
class MultiscaleDiffusionFlowEmbedder[source]
MultiscaleDiffusionFlowEmbedder(X,flows,ts=(1, 2, 4, 8),sigma_graph=0.5,sigma_embedding=0.5,flow_strength_graph=5,flow_strength_embedding=5,embedding_dimension=2,learning_rate=0.001,flow_artist='ReLU',flow_artist_shape=(2, 4, 8, 4, 2),num_flow_gaussians=25,embedder=None,decoder=None,labels=None,loss_weights=None,device=device(type='cpu')) ::Module
from directed_graphs.datasets import directed_circle, directed_cylinder, directed_spiral, directed_swiss_roll, directed_spiral_uniform, directed_swiss_roll_uniform
from directed_graphs.datasets import plot_directed_2d, plot_directed_3d
from directed_graphs.diffusion_flow_embedding import DiffusionFlowEmbedder
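import torch
import numpy as np
# Assumption: `device` is defined earlier in the notebook; CPU matches the class default.
device = torch.device('cpu')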
X, flow, labels = directed_circle(num_nodes=300, radius=1)
plot_directed_2d(X, flow, labels)
X = torch.tensor(X).float().to(device)
flow = torch.tensor(flow).float().to(device)
MFE = MultiscaleDiffusionFlowEmbedder(X, flow, device=device).to(device)
MFE.fit(n_steps=1000)
0%| | 0/1000 [00:00<?, ?it/s]
Diffusion loss 0 is 2.03975510597229 Diffusion loss 1 is 2.73725962638855 Diffusion loss 2 is 3.0151374340057373 Diffusion loss 3 is 3.1220245361328125
EPOCH 0. Loss 3.1220245361328125. Flow strength 5.0. Heatmap of P embedding is
[Figure: heatmap of the P embedding at epoch 0]
0%| | 2/1000 [00:01<11:27, 1.45it/s]
Diffusion loss 0 is 2.0292611122131348 Diffusion loss 1 is 2.730095863342285 Diffusion loss 2 is 3.0079665184020996 Diffusion loss 3 is 3.1147098541259766
0%| | 3/1000 [00:01<09:50, 1.69it/s]
Diffusion loss 0 is 2.0189754962921143 Diffusion loss 1 is 2.7233328819274902 Diffusion loss 2 is 3.0012006759643555 Diffusion loss 3 is 3.107822895050049
0%| | 4/1000 [00:02<09:01, 1.84it/s]
Diffusion loss 0 is 2.012047052383423 Diffusion loss 1 is 2.718916893005371 Diffusion loss 2 is 2.996764898300171 Diffusion loss 3 is 3.1032729148864746
0%| | 5/1000 [00:02<08:37, 1.92it/s]
Diffusion loss 0 is 2.0070102214813232 Diffusion loss 1 is 2.7159461975097656 Diffusion loss 2 is 2.993769884109497 Diffusion loss 3 is 3.1001739501953125
1%| | 6/1000 [00:03<08:21, 1.98it/s]
Diffusion loss 0 is 2.00260591506958 Diffusion loss 1 is 2.7136311531066895 Diffusion loss 2 is 2.9914305210113525 Diffusion loss 3 is 3.097743034362793
1%| | 7/1000 [00:03<08:19, 1.99it/s]
Diffusion loss 0 is 1.9991872310638428 Diffusion loss 1 is 2.7117888927459717 Diffusion loss 2 is 2.9895663261413574 Diffusion loss 3 is 3.095801830291748
1%| | 8/1000 [00:04<08:13, 2.01it/s]
Diffusion loss 0 is 1.9976664781570435 Diffusion loss 1 is 2.711923599243164 Diffusion loss 2 is 2.9896810054779053 Diffusion loss 3 is 3.095852851867676
1%| | 9/1000 [00:04<08:13, 2.01it/s]
Diffusion loss 0 is 1.99664306640625 Diffusion loss 1 is 2.7116823196411133 Diffusion loss 2 is 2.989410638809204 Diffusion loss 3 is 3.0955164432525635
1%| | 10/1000 [00:05<08:09, 2.02it/s]
Diffusion loss 0 is 1.9967330694198608 Diffusion loss 1 is 2.7120280265808105 Diffusion loss 2 is 2.9897289276123047 Diffusion loss 3 is 3.095778465270996
1%| | 11/1000 [00:05<07:59, 2.06it/s]
Diffusion loss 0 is 1.998558759689331 Diffusion loss 1 is 2.7139649391174316 Diffusion loss 2 is 2.9916367530822754 Diffusion loss 3 is 3.09763503074646
1%| | 12/1000 [00:06<08:02, 2.05it/s]
Diffusion loss 0 is 1.9998202323913574 Diffusion loss 1 is 2.715221405029297 Diffusion loss 2 is 2.992865562438965 Diffusion loss 3 is 3.0988197326660156
1%|▏ | 13/1000 [00:06<07:56, 2.07it/s]
Diffusion loss 0 is 2.0002341270446777 Diffusion loss 1 is 2.71618390083313 Diffusion loss 2 is 2.993809461593628 Diffusion loss 3 is 3.099735736846924
1%|▏ | 14/1000 [00:07<07:55, 2.07it/s]
Diffusion loss 0 is 2.002619981765747 Diffusion loss 1 is 2.7180428504943848 Diffusion loss 2 is 2.9956412315368652 Diffusion loss 3 is 3.1015336513519287
2%|▏ | 15/1000 [00:07<07:49, 2.10it/s]
Diffusion loss 0 is 2.0064282417297363 Diffusion loss 1 is 2.720560312271118 Diffusion loss 2 is 2.998126268386841 Diffusion loss 3 is 3.1039843559265137
2%|▏ | 16/1000 [00:08<07:50, 2.09it/s]
Diffusion loss 0 is 2.0098114013671875 Diffusion loss 1 is 2.7236435413360596 Diffusion loss 2 is 3.0011842250823975 Diffusion loss 3 is 3.1070220470428467
2%|▏ | 17/1000 [00:08<07:59, 2.05it/s]
Diffusion loss 0 is 2.0133814811706543 Diffusion loss 1 is 2.726738691329956 Diffusion loss 2 is 3.0042521953582764 Diffusion loss 3 is 3.110074281692505
2%|▏ | 18/1000 [00:09<07:55, 2.07it/s]
Diffusion loss 0 is 2.016296863555908 Diffusion loss 1 is 2.728565216064453 Diffusion loss 2 is 3.006044626235962 Diffusion loss 3 is 3.1118476390838623
2%|▏ | 19/1000 [00:09<08:01, 2.04it/s]
Diffusion loss 0 is 2.0210165977478027 Diffusion loss 1 is 2.7318115234375 Diffusion loss 2 is 3.009251117706299 Diffusion loss 3 is 3.1150357723236084
2%|▏ | 20/1000 [00:10<07:59, 2.04it/s]
Diffusion loss 0 is 2.024796724319458 Diffusion loss 1 is 2.734341621398926 Diffusion loss 2 is 3.0117392539978027 Diffusion loss 3 is 3.11751127243042
2%|▏ | 21/1000 [00:10<07:54, 2.06it/s]
Diffusion loss 0 is 2.0307071208953857 Diffusion loss 1 is 2.7386395931243896 Diffusion loss 2 is 3.015990734100342 Diffusion loss 3 is 3.121751070022583
2%|▏ | 22/1000 [00:11<07:53, 2.07it/s]
Diffusion loss 0 is 2.037219762802124 Diffusion loss 1 is 2.743382453918457 Diffusion loss 2 is 3.0206847190856934 Diffusion loss 3 is 3.126438617706299
2%|▏ | 23/1000 [00:11<08:02, 2.03it/s]
Diffusion loss 0 is 2.0420148372650146 Diffusion loss 1 is 2.7462852001190186 Diffusion loss 2 is 3.023536205291748 Diffusion loss 3 is 3.1292808055877686
2%|▏ | 24/1000 [00:12<07:57, 2.04it/s]
Diffusion loss 0 is 2.0466208457946777 Diffusion loss 1 is 2.7490015029907227 Diffusion loss 2 is 3.026197910308838 Diffusion loss 3 is 3.131936550140381
2%|▎ | 25/1000 [00:12<07:50, 2.07it/s]
Diffusion loss 0 is 2.0530755519866943 Diffusion loss 1 is 2.7535979747772217 Diffusion loss 2 is 3.0307319164276123 Diffusion loss 3 is 3.1364657878875732
3%|▎ | 26/1000 [00:13<07:47, 2.09it/s]
Diffusion loss 0 is 2.0579400062561035 Diffusion loss 1 is 2.7561774253845215 Diffusion loss 2 is 3.0332400798797607 Diffusion loss 3 is 3.138969898223877
3%|▎ | 27/1000 [00:13<07:44, 2.09it/s]
Diffusion loss 0 is 2.064486026763916 Diffusion loss 1 is 2.7601006031036377 Diffusion loss 2 is 3.0370841026306152 Diffusion loss 3 is 3.1428120136260986
3%|▎ | 28/1000 [00:14<07:44, 2.09it/s]
Diffusion loss 0 is 2.072329044342041 Diffusion loss 1 is 2.7650146484375 Diffusion loss 2 is 3.0419058799743652 Diffusion loss 3 is 3.147632360458374
3%|▎ | 29/1000 [00:14<07:42, 2.10it/s]
Diffusion loss 0 is 2.0767548084259033 Diffusion loss 1 is 2.7665822505950928 Diffusion loss 2 is 3.0433719158172607 Diffusion loss 3 is 3.149101495742798
3%|▎ | 30/1000 [00:14<07:40, 2.11it/s]
Diffusion loss 0 is 2.085530996322632 Diffusion loss 1 is 2.772172451019287 Diffusion loss 2 is 3.0488407611846924 Diffusion loss 3 is 3.154575824737549
3%|▎ | 31/1000 [00:15<07:46, 2.08it/s]
Diffusion loss 0 is 2.0957610607147217 Diffusion loss 1 is 2.7790567874908447 Diffusion loss 2 is 3.055595874786377 Diffusion loss 3 is 3.161344051361084
3%|▎ | 32/1000 [00:15<07:49, 2.06it/s]
Diffusion loss 0 is 2.1112711429595947 Diffusion loss 1 is 2.791167736053467 Diffusion loss 2 is 3.0675556659698486 Diffusion loss 3 is 3.173318862915039
3%|▎ | 33/1000 [00:16<07:46, 2.07it/s]
Diffusion loss 0 is 2.124772548675537 Diffusion loss 1 is 2.800870418548584 Diffusion loss 2 is 3.077091932296753 Diffusion loss 3 is 3.1828761100769043
3%|▎ | 34/1000 [00:16<07:38, 2.10it/s]
Diffusion loss 0 is 2.132888078689575 Diffusion loss 1 is 2.8044772148132324 Diffusion loss 2 is 3.0804951190948486 Diffusion loss 3 is 3.1862988471984863
4%|▎ | 35/1000 [00:17<07:40, 2.09it/s]
Diffusion loss 0 is 2.1520133018493652 Diffusion loss 1 is 2.8193533420562744 Diffusion loss 2 is 3.095139265060425 Diffusion loss 3 is 3.2009644508361816
4%|▎ | 36/1000 [00:17<07:38, 2.10it/s]
Diffusion loss 0 is 2.1742148399353027 Diffusion loss 1 is 2.8369531631469727 Diffusion loss 2 is 3.112487316131592 Diffusion loss 3 is 3.218341827392578
4%|▎ | 37/1000 [00:18<07:32, 2.13it/s]
Diffusion loss 0 is 2.199885606765747 Diffusion loss 1 is 2.8578600883483887 Diffusion loss 2 is 3.1331076622009277 Diffusion loss 3 is 3.2389938831329346
4%|▍ | 38/1000 [00:18<07:33, 2.12it/s]
Diffusion loss 0 is 2.219031572341919 Diffusion loss 1 is 2.8722925186157227 Diffusion loss 2 is 3.14723539352417 Diffusion loss 3 is 3.2531609535217285
4%|▍ | 39/1000 [00:19<07:29, 2.14it/s]
Diffusion loss 0 is 2.250105142593384 Diffusion loss 1 is 2.8986165523529053 Diffusion loss 2 is 3.173224925994873 Diffusion loss 3 is 3.279193639755249
4%|▍ | 40/1000 [00:19<07:32, 2.12it/s]
Diffusion loss 0 is 2.2882821559906006 Diffusion loss 1 is 2.932119846343994 Diffusion loss 2 is 3.206362009048462 Diffusion loss 3 is 3.3123767375946045
4%|▍ | 41/1000 [00:20<07:30, 2.13it/s]
Diffusion loss 0 is 2.322924852371216 Diffusion loss 1 is 2.962155818939209 Diffusion loss 2 is 3.236011028289795 Diffusion loss 3 is 3.3420796394348145
4%|▍ | 42/1000 [00:20<07:32, 2.12it/s]
Diffusion loss 0 is 2.3640329837799072 Diffusion loss 1 is 2.998786449432373 Diffusion loss 2 is 3.272233009338379 Diffusion loss 3 is 3.378357410430908
4%|▍ | 43/1000 [00:21<07:34, 2.11it/s]
Diffusion loss 0 is 2.4036333560943604 Diffusion loss 1 is 3.0341620445251465 Diffusion loss 2 is 3.3071846961975098 Diffusion loss 3 is 3.41336727142334
4%|▍ | 44/1000 [00:21<07:34, 2.10it/s]
Diffusion loss 0 is 2.454817295074463 Diffusion loss 1 is 3.0812571048736572 Diffusion loss 2 is 3.353847026824951 Diffusion loss 3 is 3.460090398788452
4%|▍ | 45/1000 [00:22<07:42, 2.07it/s]
Diffusion loss 0 is 2.505533218383789 Diffusion loss 1 is 3.1282639503479004 Diffusion loss 2 is 3.4004430770874023 Diffusion loss 3 is 3.5067505836486816
5%|▍ | 46/1000 [00:22<07:42, 2.06it/s]
Diffusion loss 0 is 2.5521721839904785 Diffusion loss 1 is 3.1716599464416504 Diffusion loss 2 is 3.4434680938720703 Diffusion loss 3 is 3.5498480796813965
5%|▍ | 47/1000 [00:23<07:41, 2.06it/s]
Diffusion loss 0 is 2.601017951965332 Diffusion loss 1 is 3.21734619140625 Diffusion loss 2 is 3.488769769668579 Diffusion loss 3 is 3.595217227935791
5%|▍ | 48/1000 [00:23<07:57, 1.99it/s]
Diffusion loss 0 is 2.6678476333618164 Diffusion loss 1 is 3.2811412811279297 Diffusion loss 2 is 3.5521552562713623 Diffusion loss 3 is 3.6586532592773438
5%|▍ | 49/1000 [00:24<07:53, 2.01it/s]
Diffusion loss 0 is 2.7517495155334473 Diffusion loss 1 is 3.3624765872955322 Diffusion loss 2 is 3.6331326961517334 Diffusion loss 3 is 3.739685297012329
5%|▌ | 50/1000 [00:24<07:52, 2.01it/s]
Diffusion loss 0 is 2.833007574081421 Diffusion loss 1 is 3.4416794776916504 Diffusion loss 2 is 3.7120327949523926 Diffusion loss 3 is 3.8186583518981934
5%|▌ | 51/1000 [00:25<07:52, 2.01it/s]
Diffusion loss 0 is 2.9055113792419434 Diffusion loss 1 is 3.5122835636138916 Diffusion loss 2 is 3.782335042953491 Diffusion loss 3 is 3.889025926589966
5%|▌ | 52/1000 [00:25<07:48, 2.02it/s]
Diffusion loss 0 is 2.983106851577759 Diffusion loss 1 is 3.5881240367889404 Diffusion loss 2 is 3.857867956161499 Diffusion loss 3 is 3.964601516723633
5%|▌ | 53/1000 [00:26<07:45, 2.03it/s]
Diffusion loss 0 is 3.0704171657562256 Diffusion loss 1 is 3.674048900604248 Diffusion loss 2 is 3.943547248840332 Diffusion loss 3 is 4.050338268280029
5%|▌ | 54/1000 [00:26<07:54, 1.99it/s]
Diffusion loss 0 is 3.1739323139190674 Diffusion loss 1 is 3.776221990585327 Diffusion loss 2 is 4.04547119140625 Diffusion loss 3 is 4.152300834655762
6%|▌ | 55/1000 [00:27<07:48, 2.02it/s]
Diffusion loss 0 is 3.2653720378875732 Diffusion loss 1 is 3.8671157360076904 Diffusion loss 2 is 4.136215686798096 Diffusion loss 3 is 4.24311637878418
6%|▌ | 56/1000 [00:27<07:40, 2.05it/s]
Diffusion loss 0 is 3.364753484725952 Diffusion loss 1 is 3.965951919555664 Diffusion loss 2 is 4.23487663269043 Diffusion loss 3 is 4.341814994812012
6%|▌ | 57/1000 [00:28<07:41, 2.04it/s]
Diffusion loss 0 is 3.456723690032959 Diffusion loss 1 is 4.057671070098877 Diffusion loss 2 is 4.3264546394348145 Diffusion loss 3 is 4.433437824249268
6%|▌ | 58/1000 [00:28<07:41, 2.04it/s]
Diffusion loss 0 is 3.54109787940979 Diffusion loss 1 is 4.1416401863098145 Diffusion loss 2 is 4.410285949707031 Diffusion loss 3 is 4.5173115730285645
6%|▌ | 59/1000 [00:29<07:37, 2.06it/s]
Diffusion loss 0 is 3.6124675273895264 Diffusion loss 1 is 4.212712287902832 Diffusion loss 2 is 4.4812469482421875 Diffusion loss 3 is 4.588311672210693
6%|▌ | 60/1000 [00:29<07:36, 2.06it/s]
Diffusion loss 0 is 3.684149980545044 Diffusion loss 1 is 4.284224033355713 Diffusion loss 2 is 4.552694320678711 Diffusion loss 3 is 4.659811019897461
6%|▌ | 61/1000 [00:29<07:37, 2.05it/s]
Diffusion loss 0 is 3.755814790725708 Diffusion loss 1 is 4.355713844299316 Diffusion loss 2 is 4.624081611633301 Diffusion loss 3 is 4.73121452331543
6%|▌ | 62/1000 [00:30<07:33, 2.07it/s]
Diffusion loss 0 is 3.826239585876465 Diffusion loss 1 is 4.425790786743164 Diffusion loss 2 is 4.6940999031066895 Diffusion loss 3 is 4.801302909851074
6%|▋ | 63/1000 [00:30<07:32, 2.07it/s]
Diffusion loss 0 is 3.8880934715270996 Diffusion loss 1 is 4.48719596862793 Diffusion loss 2 is 4.755460739135742 Diffusion loss 3 is 4.862725257873535
6%|▋ | 64/1000 [00:31<07:33, 2.07it/s]
Diffusion loss 0 is 3.9578845500946045 Diffusion loss 1 is 4.556700706481934 Diffusion loss 2 is 4.824909210205078 Diffusion loss 3 is 4.932220935821533
6%|▋ | 65/1000 [00:31<07:32, 2.07it/s]
Diffusion loss 0 is 4.036159992218018 Diffusion loss 1 is 4.634644031524658 Diffusion loss 2 is 4.902827262878418 Diffusion loss 3 is 5.010200023651123
7%|▋ | 66/1000 [00:32<07:40, 2.03it/s]
Diffusion loss 0 is 4.108985424041748 Diffusion loss 1 is 4.707339286804199 Diffusion loss 2 is 4.975471019744873 Diffusion loss 3 is 5.082864284515381
7%|▋ | 67/1000 [00:32<07:49, 1.99it/s]
Diffusion loss 0 is 4.177502632141113 Diffusion loss 1 is 4.775913238525391 Diffusion loss 2 is 5.043994426727295 Diffusion loss 3 is 5.151389122009277
7%|▋ | 68/1000 [00:33<07:43, 2.01it/s]
Diffusion loss 0 is 4.250778675079346 Diffusion loss 1 is 4.848609924316406 Diffusion loss 2 is 5.116585731506348 Diffusion loss 3 is 5.223998546600342
7%|▋ | 69/1000 [00:33<07:38, 2.03it/s]
Diffusion loss 0 is 4.32791805267334 Diffusion loss 1 is 4.925227165222168 Diffusion loss 2 is 5.193116188049316 Diffusion loss 3 is 5.300550937652588
7%|▋ | 70/1000 [00:34<07:33, 2.05it/s]
Diffusion loss 0 is 4.400700569152832 Diffusion loss 1 is 4.996689319610596 Diffusion loss 2 is 5.264405250549316 Diffusion loss 3 is 5.371856212615967
7%|▋ | 71/1000 [00:34<07:39, 2.02it/s]
Diffusion loss 0 is 4.4817891120910645 Diffusion loss 1 is 5.077026844024658 Diffusion loss 2 is 5.344655513763428 Diffusion loss 3 is 5.452112674713135
7%|▋ | 72/1000 [00:35<07:35, 2.04it/s]
Diffusion loss 0 is 4.568131923675537 Diffusion loss 1 is 5.162829399108887 Diffusion loss 2 is 5.430361747741699 Diffusion loss 3 is 5.537817478179932
7%|▋ | 73/1000 [00:35<07:42, 2.00it/s]
Diffusion loss 0 is 4.634922027587891 Diffusion loss 1 is 5.22886848449707 Diffusion loss 2 is 5.496279716491699 Diffusion loss 3 is 5.603742599487305
7%|▋ | 74/1000 [00:36<07:50, 1.97it/s]
Diffusion loss 0 is 4.698731422424316 Diffusion loss 1 is 5.292564392089844 Diffusion loss 2 is 5.559948921203613 Diffusion loss 3 is 5.667421817779541
8%|▊ | 75/1000 [00:36<07:51, 1.96it/s]
Diffusion loss 0 is 4.782403469085693 Diffusion loss 1 is 5.376194477081299 Diffusion loss 2 is 5.643599987030029 Diffusion loss 3 is 5.751086711883545
8%|▊ | 76/1000 [00:37<07:46, 1.98it/s]
Diffusion loss 0 is 4.858648300170898 Diffusion loss 1 is 5.452186584472656 Diffusion loss 2 is 5.719559669494629 Diffusion loss 3 is 5.827061176300049
8%|▊ | 77/1000 [00:37<07:44, 1.99it/s]
Diffusion loss 0 is 4.942154407501221 Diffusion loss 1 is 5.535473346710205 Diffusion loss 2 is 5.80284309387207 Diffusion loss 3 is 5.910357475280762
8%|▊ | 78/1000 [00:38<07:44, 1.98it/s]
Diffusion loss 0 is 4.980816841125488 Diffusion loss 1 is 5.573556900024414 Diffusion loss 2 is 5.840866565704346 Diffusion loss 3 is 5.948390960693359
8%|▊ | 79/1000 [00:38<07:41, 1.99it/s]
Diffusion loss 0 is 4.983737945556641 Diffusion loss 1 is 5.576536655426025 Diffusion loss 2 is 5.843900203704834 Diffusion loss 3 is 5.951450824737549
8%|▊ | 80/1000 [00:39<07:35, 2.02it/s]
Diffusion loss 0 is 5.002470970153809 Diffusion loss 1 is 5.595953941345215 Diffusion loss 2 is 5.863428115844727 Diffusion loss 3 is 5.970972061157227
8%|▊ | 81/1000 [00:39<07:38, 2.01it/s]
Diffusion loss 0 is 5.013365268707275 Diffusion loss 1 is 5.60740852355957 Diffusion loss 2 is 5.875025749206543 Diffusion loss 3 is 5.982600212097168
8%|▊ | 82/1000 [00:40<07:33, 2.02it/s]
Diffusion loss 0 is 5.014937400817871 Diffusion loss 1 is 5.608851909637451 Diffusion loss 2 is 5.87650203704834 Diffusion loss 3 is 5.984106540679932
8%|▊ | 83/1000 [00:40<07:26, 2.05it/s]
Diffusion loss 0 is 5.036396026611328 Diffusion loss 1 is 5.6298956871032715 Diffusion loss 2 is 5.897563934326172 Diffusion loss 3 is 6.005195140838623
8%|▊ | 84/1000 [00:41<07:30, 2.03it/s]
Diffusion loss 0 is 5.060315132141113 Diffusion loss 1 is 5.65366792678833 Diffusion loss 2 is 5.921413421630859 Diffusion loss 3 is 6.029091835021973
8%|▊ | 85/1000 [00:41<07:25, 2.05it/s]
Diffusion loss 0 is 5.117713928222656 Diffusion loss 1 is 5.710878372192383 Diffusion loss 2 is 5.978695869445801 Diffusion loss 3 is 6.086450099945068
9%|▊ | 86/1000 [00:42<07:19, 2.08it/s]
Diffusion loss 0 is 5.15546989440918 Diffusion loss 1 is 5.748571395874023 Diffusion loss 2 is 6.016458034515381 Diffusion loss 3 is 6.124281406402588
9%|▊ | 87/1000 [00:42<07:06, 2.14it/s]
Diffusion loss 0 is 5.207911968231201 Diffusion loss 1 is 5.8009161949157715 Diffusion loss 2 is 6.068885803222656 Diffusion loss 3 is 6.176769256591797
9%|▉ | 88/1000 [00:43<07:06, 2.14it/s]
Diffusion loss 0 is 5.2189435958862305 Diffusion loss 1 is 5.811193466186523 Diffusion loss 2 is 6.079214572906494 Diffusion loss 3 is 6.187180042266846
9%|▉ | 89/1000 [00:43<07:05, 2.14it/s]
Diffusion loss 0 is 5.236286640167236 Diffusion loss 1 is 5.8288679122924805 Diffusion loss 2 is 6.097057819366455 Diffusion loss 3 is 6.205112934112549
9%|▉ | 90/1000 [00:44<07:16, 2.09it/s]
Diffusion loss 0 is 5.239874839782715 Diffusion loss 1 is 5.833438396453857 Diffusion loss 2 is 6.101773738861084 Diffusion loss 3 is 6.209824562072754
9%|▉ | 91/1000 [00:44<07:13, 2.09it/s]
Diffusion loss 0 is 5.269533157348633 Diffusion loss 1 is 5.863594055175781 Diffusion loss 2 is 6.132017612457275 Diffusion loss 3 is 6.24006986618042
9%|▉ | 92/1000 [00:45<07:15, 2.08it/s]
Diffusion loss 0 is 5.2967424392700195 Diffusion loss 1 is 5.891350746154785 Diffusion loss 2 is 6.159935474395752 Diffusion loss 3 is 6.268002986907959
9%|▉ | 93/1000 [00:45<07:18, 2.07it/s]
Diffusion loss 0 is 5.284158706665039 Diffusion loss 1 is 5.879601955413818 Diffusion loss 2 is 6.148329257965088 Diffusion loss 3 is 6.256409645080566
9%|▉ | 94/1000 [00:46<07:17, 2.07it/s]
Diffusion loss 0 is 5.30442476272583 Diffusion loss 1 is 5.9002861976623535 Diffusion loss 2 is 6.1691131591796875 Diffusion loss 3 is 6.277200222015381
10%|▉ | 95/1000 [00:46<07:16, 2.07it/s]
Diffusion loss 0 is 5.331946849822998 Diffusion loss 1 is 5.928137302398682 Diffusion loss 2 is 6.197081565856934 Diffusion loss 3 is 6.3051910400390625
10%|▉ | 96/1000 [00:47<07:29, 2.01it/s]
Diffusion loss 0 is 5.368183612823486 Diffusion loss 1 is 5.964594841003418 Diffusion loss 2 is 6.2335638999938965 Diffusion loss 3 is 6.341667652130127
10%|▉ | 97/1000 [00:47<07:29, 2.01it/s]
Diffusion loss 0 is 5.424108982086182 Diffusion loss 1 is 6.020517826080322 Diffusion loss 2 is 6.289473056793213 Diffusion loss 3 is 6.3975605964660645
10%|▉ | 98/1000 [00:48<07:25, 2.02it/s]
Diffusion loss 0 is 5.432829856872559 Diffusion loss 1 is 6.029191017150879 Diffusion loss 2 is 6.29815149307251 Diffusion loss 3 is 6.406239986419678
10%|▉ | 99/1000 [00:48<07:21, 2.04it/s]
Diffusion loss 0 is 5.474944591522217 Diffusion loss 1 is 6.071273326873779 Diffusion loss 2 is 6.340230464935303 Diffusion loss 3 is 6.448312759399414
10%|█ | 100/1000 [00:49<07:20, 2.04it/s]
Diffusion loss 0 is 5.499778747558594 Diffusion loss 1 is 6.09612512588501 Diffusion loss 2 is 6.365105628967285 Diffusion loss 3 is 6.473206043243408
Diffusion loss 0 is 5.499994277954102 Diffusion loss 1 is 6.096709728240967 Diffusion loss 2 is 6.3657636642456055 Diffusion loss 3 is 6.473880767822266
EPOCH 100. Loss 6.473880767822266. Flow strength 5.0. Heatmap of P embedding is
[Figure: heatmap of the P embedding at epoch 100]
10%|█ | 102/1000 [00:50<08:20, 1.80it/s]
Diffusion loss 0 is 5.473235607147217 Diffusion loss 1 is 6.0702033042907715 Diffusion loss 2 is 6.339330196380615 Diffusion loss 3 is 6.447473049163818
10%|█ | 103/1000 [00:50<08:02, 1.86it/s]
Diffusion loss 0 is 5.344992160797119 Diffusion loss 1 is 5.9423041343688965 Diffusion loss 2 is 6.211526870727539 Diffusion loss 3 is 6.319693088531494
10%|█ | 104/1000 [00:51<07:37, 1.96it/s]
Diffusion loss 0 is 5.331064224243164 Diffusion loss 1 is 5.928541660308838 Diffusion loss 2 is 6.197878837585449 Diffusion loss 3 is 6.306141376495361
10%|█ | 105/1000 [00:51<07:27, 2.00it/s]
Diffusion loss 0 is 5.345742225646973 Diffusion loss 1 is 5.943017482757568 Diffusion loss 2 is 6.212438106536865 Diffusion loss 3 is 6.320810794830322
11%|█ | 106/1000 [00:52<07:22, 2.02it/s]
Diffusion loss 0 is 5.370551586151123 Diffusion loss 1 is 5.966885566711426 Diffusion loss 2 is 6.236337661743164 Diffusion loss 3 is 6.344841003417969
11%|█ | 107/1000 [00:52<07:21, 2.02it/s]
Diffusion loss 0 is 5.422659397125244 Diffusion loss 1 is 6.018204689025879 Diffusion loss 2 is 6.287646293640137 Diffusion loss 3 is 6.396261692047119
11%|█ | 108/1000 [00:53<07:18, 2.03it/s]
Diffusion loss 0 is 5.466982364654541 Diffusion loss 1 is 6.061953544616699 Diffusion loss 2 is 6.331338882446289 Diffusion loss 3 is 6.440021514892578
11%|█ | 109/1000 [00:53<07:21, 2.02it/s]
Diffusion loss 0 is 5.528647422790527 Diffusion loss 1 is 6.122838497161865 Diffusion loss 2 is 6.392098903656006 Diffusion loss 3 is 6.500798225402832
11%|█ | 110/1000 [00:54<07:19, 2.02it/s]
Diffusion loss 0 is 5.574257850646973 Diffusion loss 1 is 6.167624473571777 Diffusion loss 2 is 6.436777591705322 Diffusion loss 3 is 6.54550313949585
11%|█ | 111/1000 [00:54<07:18, 2.03it/s]
Diffusion loss 0 is 5.619692325592041 Diffusion loss 1 is 6.211854934692383 Diffusion loss 2 is 6.480871200561523 Diffusion loss 3 is 6.5896124839782715
11%|█ | 112/1000 [00:55<07:19, 2.02it/s]
Diffusion loss 0 is 5.6691155433654785 Diffusion loss 1 is 6.2608642578125 Diffusion loss 2 is 6.5297675132751465 Diffusion loss 3 is 6.638502597808838
11%|█▏ | 113/1000 [00:55<07:14, 2.04it/s]
Diffusion loss 0 is 5.720432281494141 Diffusion loss 1 is 6.311663627624512 Diffusion loss 2 is 6.5804667472839355 Diffusion loss 3 is 6.6891913414001465
11%|█▏ | 114/1000 [00:56<07:15, 2.04it/s]
Diffusion loss 0 is 5.786249160766602 Diffusion loss 1 is 6.376924514770508 Diffusion loss 2 is 6.645636558532715 Diffusion loss 3 is 6.754351615905762
12%|█▏ | 115/1000 [00:56<07:12, 2.05it/s]
Diffusion loss 0 is 5.829891204833984 Diffusion loss 1 is 6.420021057128906 Diffusion loss 2 is 6.6885552406311035 Diffusion loss 3 is 6.797219753265381
12%|█▏ | 116/1000 [00:57<07:10, 2.05it/s]
Diffusion loss 0 is 5.871068477630615 Diffusion loss 1 is 6.460274696350098 Diffusion loss 2 is 6.7286152839660645 Diffusion loss 3 is 6.837244987487793
12%|█▏ | 117/1000 [00:57<07:11, 2.05it/s]
Diffusion loss 0 is 5.992489814758301 Diffusion loss 1 is 6.5808491706848145 Diffusion loss 2 is 6.848996639251709 Diffusion loss 3 is 6.957601070404053
12%|█▏ | 118/1000 [00:58<07:08, 2.06it/s]
Diffusion loss 0 is 6.055846214294434 Diffusion loss 1 is 6.643019199371338 Diffusion loss 2 is 6.910974502563477 Diffusion loss 3 is 7.019532203674316
12%|█▏ | 119/1000 [00:58<07:11, 2.04it/s]
Diffusion loss 0 is 6.175665855407715 Diffusion loss 1 is 6.761777877807617 Diffusion loss 2 is 7.029616832733154 Diffusion loss 3 is 7.138182163238525
12%|█▏ | 120/1000 [00:59<07:22, 1.99it/s]
Diffusion loss 0 is 6.283575057983398 Diffusion loss 1 is 6.868997097015381 Diffusion loss 2 is 7.13682222366333 Diffusion loss 3 is 7.245453834533691
12%|█▏ | 121/1000 [00:59<07:21, 1.99it/s]
Diffusion loss 0 is 6.370048522949219 Diffusion loss 1 is 6.954488277435303 Diffusion loss 2 is 7.2222371101379395 Diffusion loss 3 is 7.330890655517578
12%|█▏ | 122/1000 [01:00<07:21, 1.99it/s]
Diffusion loss 0 is 6.386755466461182 Diffusion loss 1 is 6.9706525802612305 Diffusion loss 2 is 7.238280296325684 Diffusion loss 3 is 7.346877574920654
12%|█▏ | 123/1000 [01:00<07:24, 1.97it/s]
Diffusion loss 0 is 6.411212921142578 Diffusion loss 1 is 6.994761943817139 Diffusion loss 2 is 7.262314319610596 Diffusion loss 3 is 7.370904922485352
12%|█▏ | 124/1000 [01:01<07:17, 2.00it/s]
Diffusion loss 0 is 6.412768363952637 Diffusion loss 1 is 6.996185779571533 Diffusion loss 2 is 7.263677597045898 Diffusion loss 3 is 7.372255802154541
12%|█▎ | 125/1000 [01:01<07:14, 2.01it/s]
Diffusion loss 0 is 6.4177565574646 Diffusion loss 1 is 7.001051902770996 Diffusion loss 2 is 7.268502712249756 Diffusion loss 3 is 7.377071380615234
13%|█▎ | 126/1000 [01:02<07:16, 2.00it/s]
Diffusion loss 0 is 6.438887596130371 Diffusion loss 1 is 7.021982192993164 Diffusion loss 2 is 7.289422988891602 Diffusion loss 3 is 7.39799690246582
13%|█▎ | 127/1000 [01:02<07:14, 2.01it/s]
Diffusion loss 0 is 6.466792106628418 Diffusion loss 1 is 7.049732208251953 Diffusion loss 2 is 7.317075729370117 Diffusion loss 3 is 7.425604820251465
13%|█▎ | 128/1000 [01:03<07:07, 2.04it/s]
Diffusion loss 0 is 6.474811553955078 Diffusion loss 1 is 7.057693004608154 Diffusion loss 2 is 7.325007915496826 Diffusion loss 3 is 7.433533191680908
13%|█▎ | 129/1000 [01:03<07:05, 2.05it/s]
Diffusion loss 0 is 6.484582424163818 Diffusion loss 1 is 7.067434310913086 Diffusion loss 2 is 7.334718704223633 Diffusion loss 3 is 7.443238258361816
13%|█▎ | 130/1000 [01:04<07:03, 2.06it/s]
Diffusion loss 0 is 6.502688884735107 Diffusion loss 1 is 7.085456371307373 Diffusion loss 2 is 7.352733135223389 Diffusion loss 3 is 7.46126127243042
13%|█▎ | 131/1000 [01:04<07:00, 2.07it/s]
Diffusion loss 0 is 6.516218662261963 Diffusion loss 1 is 7.098977088928223 Diffusion loss 2 is 7.366218566894531 Diffusion loss 3 is 7.474731922149658
13%|█▎ | 132/1000 [01:05<06:56, 2.08it/s]
Diffusion loss 0 is 6.513696670532227 Diffusion loss 1 is 7.096558570861816 Diffusion loss 2 is 7.363754749298096 Diffusion loss 3 is 7.472245216369629
13%|█▎ | 133/1000 [01:05<06:54, 2.09it/s]
Diffusion loss 0 is 6.503399848937988 Diffusion loss 1 is 7.08630895614624 Diffusion loss 2 is 7.35345983505249 Diffusion loss 3 is 7.461925983428955
13%|█▎ | 134/1000 [01:06<06:53, 2.09it/s]
Diffusion loss 0 is 6.512179851531982 Diffusion loss 1 is 7.095236301422119 Diffusion loss 2 is 7.362363815307617 Diffusion loss 3 is 7.47081184387207
14%|█▎ | 135/1000 [01:06<06:52, 2.10it/s]
Diffusion loss 0 is 6.532444953918457 Diffusion loss 1 is 7.115674018859863 Diffusion loss 2 is 7.382816314697266 Diffusion loss 3 is 7.491260528564453
14%|█▎ | 136/1000 [01:06<06:53, 2.09it/s]
Diffusion loss 0 is 6.53363561630249 Diffusion loss 1 is 7.116999626159668 Diffusion loss 2 is 7.384142875671387 Diffusion loss 3 is 7.492580413818359
14%|█▎ | 137/1000 [01:07<06:53, 2.09it/s]
Diffusion loss 0 is 6.539397716522217 Diffusion loss 1 is 7.122920036315918 Diffusion loss 2 is 7.390068054199219 Diffusion loss 3 is 7.4985032081604
14%|█▍ | 138/1000 [01:07<06:45, 2.12it/s]
Diffusion loss 0 is 6.530201435089111 Diffusion loss 1 is 7.11392068862915 Diffusion loss 2 is 7.381082534790039 Diffusion loss 3 is 7.4895195960998535
14%|█▍ | 139/1000 [01:08<06:47, 2.11it/s]
Diffusion loss 0 is 6.53363561630249 Diffusion loss 1 is 7.1175689697265625 Diffusion loss 2 is 7.384732723236084 Diffusion loss 3 is 7.493175029754639
14%|█▍ | 140/1000 [01:08<06:47, 2.11it/s]
Diffusion loss 0 is 6.529900550842285 Diffusion loss 1 is 7.114009380340576 Diffusion loss 2 is 7.381196022033691 Diffusion loss 3 is 7.489655017852783
14%|█▍ | 141/1000 [01:09<06:47, 2.11it/s]
Diffusion loss 0 is 6.521722316741943 Diffusion loss 1 is 7.1060638427734375 Diffusion loss 2 is 7.373279094696045 Diffusion loss 3 is 7.481754779815674
14%|█▍ | 142/1000 [01:09<06:41, 2.14it/s]
Diffusion loss 0 is 6.518465518951416 Diffusion loss 1 is 7.103052139282227 Diffusion loss 2 is 7.370311737060547 Diffusion loss 3 is 7.478806972503662
14%|█▍ | 143/1000 [01:10<06:43, 2.13it/s]
Diffusion loss 0 is 6.511197090148926 Diffusion loss 1 is 7.095990180969238 Diffusion loss 2 is 7.3632965087890625 Diffusion loss 3 is 7.4718146324157715
14%|█▍ | 144/1000 [01:10<06:53, 2.07it/s]
Diffusion loss 0 is 6.518250465393066 Diffusion loss 1 is 7.103304386138916 Diffusion loss 2 is 7.370654106140137 Diffusion loss 3 is 7.479190826416016
14%|█▍ | 145/1000 [01:11<06:57, 2.05it/s]
Diffusion loss 0 is 6.525689601898193 Diffusion loss 1 is 7.110926151275635 Diffusion loss 2 is 7.378272533416748 Diffusion loss 3 is 7.4868268966674805
15%|█▍ | 146/1000 [01:11<06:57, 2.05it/s]
Diffusion loss 0 is 6.523996829986572 Diffusion loss 1 is 7.109498500823975 Diffusion loss 2 is 7.376922607421875 Diffusion loss 3 is 7.485509872436523
15%|█▍ | 147/1000 [01:12<06:53, 2.06it/s]
Diffusion loss 0 is 6.522687911987305 Diffusion loss 1 is 7.1084747314453125 Diffusion loss 2 is 7.375977039337158 Diffusion loss 3 is 7.484596252441406
15%|█▍ | 148/1000 [01:12<06:50, 2.07it/s]
Diffusion loss 0 is 6.5222601890563965 Diffusion loss 1 is 7.1082963943481445 Diffusion loss 2 is 7.375864028930664 Diffusion loss 3 is 7.4845170974731445
15%|█▍ | 149/1000 [01:13<06:49, 2.08it/s]
Diffusion loss 0 is 6.516529083251953 Diffusion loss 1 is 7.102811336517334 Diffusion loss 2 is 7.370441913604736 Diffusion loss 3 is 7.4791259765625
15%|█▌ | 150/1000 [01:13<06:52, 2.06it/s]
Diffusion loss 0 is 6.5141730308532715 Diffusion loss 1 is 7.100712776184082 Diffusion loss 2 is 7.368415832519531 Diffusion loss 3 is 7.477129936218262
15%|█▌ | 151/1000 [01:14<06:51, 2.06it/s]
Diffusion loss 0 is 6.500705718994141 Diffusion loss 1 is 7.087507247924805 Diffusion loss 2 is 7.35528564453125 Diffusion loss 3 is 7.464030742645264
15%|█▌ | 152/1000 [01:14<06:56, 2.03it/s]
Diffusion loss 0 is 6.48459529876709 Diffusion loss 1 is 7.07165002822876 Diffusion loss 2 is 7.3394880294799805 Diffusion loss 3 is 7.4482622146606445
15%|█▌ | 153/1000 [01:15<06:55, 2.04it/s]
Diffusion loss 0 is 6.482011795043945 Diffusion loss 1 is 7.069344520568848 Diffusion loss 2 is 7.337277412414551 Diffusion loss 3 is 7.446089267730713
15%|█▌ | 154/1000 [01:15<06:50, 2.06it/s]
Diffusion loss 0 is 6.477871417999268 Diffusion loss 1 is 7.065436363220215 Diffusion loss 2 is 7.333425521850586 Diffusion loss 3 is 7.442266941070557
16%|█▌ | 155/1000 [01:16<06:53, 2.04it/s]
Diffusion loss 0 is 6.472053050994873 Diffusion loss 1 is 7.059813976287842 Diffusion loss 2 is 7.3278656005859375 Diffusion loss 3 is 7.436738967895508
16%|█▌ | 156/1000 [01:16<06:56, 2.03it/s]
Diffusion loss 0 is 6.4665985107421875 Diffusion loss 1 is 7.054559707641602 Diffusion loss 2 is 7.322671413421631 Diffusion loss 3 is 7.431575775146484
16%|█▌ | 157/1000 [01:17<06:50, 2.05it/s]
Diffusion loss 0 is 6.4656453132629395 Diffusion loss 1 is 7.05380392074585 Diffusion loss 2 is 7.321977138519287 Diffusion loss 3 is 7.430912017822266
16%|█▌ | 158/1000 [01:17<06:49, 2.06it/s]
Diffusion loss 0 is 6.460987091064453 Diffusion loss 1 is 7.049344539642334 Diffusion loss 2 is 7.317564964294434 Diffusion loss 3 is 7.426525592803955
16%|█▌ | 159/1000 [01:18<06:46, 2.07it/s]
Diffusion loss 0 is 6.457934856414795 Diffusion loss 1 is 7.046521186828613 Diffusion loss 2 is 7.314817905426025 Diffusion loss 3 is 7.42380428314209
16%|█▌ | 160/1000 [01:18<06:46, 2.07it/s]
Diffusion loss 0 is 6.451666831970215 Diffusion loss 1 is 7.040447235107422 Diffusion loss 2 is 7.308803081512451 Diffusion loss 3 is 7.417813301086426
16%|█▌ | 161/1000 [01:19<06:46, 2.06it/s]
Diffusion loss 0 is 6.444648742675781 Diffusion loss 1 is 7.033660888671875 Diffusion loss 2 is 7.302082061767578 Diffusion loss 3 is 7.411116123199463
16%|█▌ | 162/1000 [01:19<06:49, 2.05it/s]
Diffusion loss 0 is 6.438013076782227 Diffusion loss 1 is 7.02719259262085 Diffusion loss 2 is 7.295661449432373 Diffusion loss 3 is 7.404717922210693
16%|█▋ | 163/1000 [01:20<07:01, 1.99it/s]
Diffusion loss 0 is 6.429450035095215 Diffusion loss 1 is 7.018774032592773 Diffusion loss 2 is 7.287292957305908 Diffusion loss 3 is 7.3963751792907715
16%|█▋ | 164/1000 [01:20<06:55, 2.01it/s]
Diffusion loss 0 is 6.419259548187256 Diffusion loss 1 is 7.00877571105957 Diffusion loss 2 is 7.2773590087890625 Diffusion loss 3 is 7.386464595794678
16%|█▋ | 165/1000 [01:21<06:54, 2.01it/s]
Diffusion loss 0 is 6.414787769317627 Diffusion loss 1 is 7.004441261291504 Diffusion loss 2 is 7.273070812225342 Diffusion loss 3 is 7.382201194763184
17%|█▋ | 166/1000 [01:21<06:51, 2.03it/s]
Diffusion loss 0 is 6.412536144256592 Diffusion loss 1 is 7.002317905426025 Diffusion loss 2 is 7.270997524261475 Diffusion loss 3 is 7.380152225494385
17%|█▋ | 167/1000 [01:22<06:51, 2.03it/s]
Diffusion loss 0 is 6.410750389099121 Diffusion loss 1 is 7.000671863555908 Diffusion loss 2 is 7.269404411315918 Diffusion loss 3 is 7.3785834312438965
17%|█▋ | 168/1000 [01:22<06:48, 2.04it/s]
Diffusion loss 0 is 6.407662391662598 Diffusion loss 1 is 6.997723579406738 Diffusion loss 2 is 7.266497611999512 Diffusion loss 3 is 7.375698566436768
17%|█▋ | 169/1000 [01:23<06:45, 2.05it/s]
Diffusion loss 0 is 6.403969764709473 Diffusion loss 1 is 6.994171142578125 Diffusion loss 2 is 7.2629876136779785 Diffusion loss 3 is 7.3722100257873535
17%|█▋ | 170/1000 [01:23<06:43, 2.06it/s]
Diffusion loss 0 is 6.398262023925781 Diffusion loss 1 is 6.988638877868652 Diffusion loss 2 is 7.25752067565918 Diffusion loss 3 is 7.366765022277832
17%|█▋ | 171/1000 [01:24<06:50, 2.02it/s]
Diffusion loss 0 is 6.394654273986816 Diffusion loss 1 is 6.985171794891357 Diffusion loss 2 is 7.254098892211914 Diffusion loss 3 is 7.363365173339844
17%|█▋ | 172/1000 [01:24<06:52, 2.01it/s]
Diffusion loss 0 is 6.392330169677734 Diffusion loss 1 is 6.982992649078369 Diffusion loss 2 is 7.251969814300537 Diffusion loss 3 is 7.361259460449219
17%|█▋ | 173/1000 [01:25<06:55, 1.99it/s]
Diffusion loss 0 is 6.385924816131592 Diffusion loss 1 is 6.976747512817383 Diffusion loss 2 is 7.245765209197998 Diffusion loss 3 is 7.355074405670166
17%|█▋ | 174/1000 [01:25<06:53, 2.00it/s]
Diffusion loss 0 is 6.379782199859619 Diffusion loss 1 is 6.9707722663879395 Diffusion loss 2 is 7.239848613739014 Diffusion loss 3 is 7.349180698394775
18%|█▊ | 175/1000 [01:26<06:43, 2.05it/s]
Diffusion loss 0 is 6.377206325531006 Diffusion loss 1 is 6.96836519241333 Diffusion loss 2 is 7.237486839294434 Diffusion loss 3 is 7.34683895111084
18%|█▊ | 176/1000 [01:26<06:49, 2.01it/s]
Diffusion loss 0 is 6.371298313140869 Diffusion loss 1 is 6.962624549865723 Diffusion loss 2 is 7.231804370880127 Diffusion loss 3 is 7.341179370880127
18%|█▊ | 177/1000 [01:27<06:49, 2.01it/s]
Diffusion loss 0 is 6.369543552398682 Diffusion loss 1 is 6.961040496826172 Diffusion loss 2 is 7.230274200439453 Diffusion loss 3 is 7.339670658111572
18%|█▊ | 178/1000 [01:27<06:46, 2.02it/s]
Diffusion loss 0 is 6.365799427032471 Diffusion loss 1 is 6.957461357116699 Diffusion loss 2 is 7.226768493652344 Diffusion loss 3 is 7.336190223693848
18%|█▊ | 179/1000 [01:27<06:44, 2.03it/s]
Diffusion loss 0 is 6.363340854644775 Diffusion loss 1 is 6.955168724060059 Diffusion loss 2 is 7.224531173706055 Diffusion loss 3 is 7.333974838256836
18%|█▊ | 180/1000 [01:28<06:44, 2.03it/s]
Diffusion loss 0 is 6.360254764556885 Diffusion loss 1 is 6.952263832092285 Diffusion loss 2 is 7.221676826477051 Diffusion loss 3 is 7.331140041351318
18%|█▊ | 181/1000 [01:29<06:51, 1.99it/s]
Diffusion loss 0 is 6.359091281890869 Diffusion loss 1 is 6.95126485824585 Diffusion loss 2 is 7.220726490020752 Diffusion loss 3 is 7.330211162567139
18%|█▊ | 182/1000 [01:29<06:53, 1.98it/s]
Diffusion loss 0 is 6.356137752532959 Diffusion loss 1 is 6.948476314544678 Diffusion loss 2 is 7.21799898147583 Diffusion loss 3 is 7.327507495880127
18%|█▊ | 183/1000 [01:30<06:56, 1.96it/s]
Diffusion loss 0 is 6.353737831115723 Diffusion loss 1 is 6.946262359619141 Diffusion loss 2 is 7.215847492218018 Diffusion loss 3 is 7.325376987457275
18%|█▊ | 184/1000 [01:30<06:54, 1.97it/s]
Diffusion loss 0 is 6.352316856384277 Diffusion loss 1 is 6.945024490356445 Diffusion loss 2 is 7.214667320251465 Diffusion loss 3 is 7.32421875
18%|█▊ | 185/1000 [01:31<06:52, 1.97it/s]
Diffusion loss 0 is 6.350625038146973 Diffusion loss 1 is 6.943505764007568 Diffusion loss 2 is 7.213208198547363 Diffusion loss 3 is 7.322782516479492
19%|█▊ | 186/1000 [01:31<06:46, 2.00it/s]
Diffusion loss 0 is 6.349148273468018 Diffusion loss 1 is 6.942235946655273 Diffusion loss 2 is 7.211997985839844 Diffusion loss 3 is 7.321590423583984
19%|█▊ | 187/1000 [01:32<06:41, 2.03it/s]
Diffusion loss 0 is 6.347304821014404 Diffusion loss 1 is 6.940607070922852 Diffusion loss 2 is 7.210438251495361 Diffusion loss 3 is 7.32004976272583
19%|█▉ | 188/1000 [01:32<06:38, 2.04it/s]
Diffusion loss 0 is 6.346271991729736 Diffusion loss 1 is 6.939781188964844 Diffusion loss 2 is 7.209672927856445 Diffusion loss 3 is 7.319302558898926
19%|█▉ | 189/1000 [01:32<06:35, 2.05it/s]
Diffusion loss 0 is 6.344757080078125 Diffusion loss 1 is 6.938443183898926 Diffusion loss 2 is 7.208402156829834 Diffusion loss 3 is 7.318054676055908
19%|█▉ | 190/1000 [01:33<06:31, 2.07it/s]
Diffusion loss 0 is 6.3428120613098145 Diffusion loss 1 is 6.936678886413574 Diffusion loss 2 is 7.206702709197998 Diffusion loss 3 is 7.316375255584717
19%|█▉ | 191/1000 [01:33<06:31, 2.07it/s]
Diffusion loss 0 is 6.340928077697754 Diffusion loss 1 is 6.934995174407959 Diffusion loss 2 is 7.205080986022949 Diffusion loss 3 is 7.314770221710205
19%|█▉ | 192/1000 [01:34<06:29, 2.08it/s]
Diffusion loss 0 is 6.339548587799072 Diffusion loss 1 is 6.933799743652344 Diffusion loss 2 is 7.203946590423584 Diffusion loss 3 is 7.313654899597168
19%|█▉ | 193/1000 [01:34<06:29, 2.07it/s]
Diffusion loss 0 is 6.337381362915039 Diffusion loss 1 is 6.93182897567749 Diffusion loss 2 is 7.202036380767822 Diffusion loss 3 is 7.311762809753418
19%|█▉ | 194/1000 [01:35<06:30, 2.06it/s]
Diffusion loss 0 is 6.335491180419922 Diffusion loss 1 is 6.930113792419434 Diffusion loss 2 is 7.200369358062744 Diffusion loss 3 is 7.310113906860352
20%|█▉ | 195/1000 [01:35<06:25, 2.09it/s]
Diffusion loss 0 is 6.333651065826416 Diffusion loss 1 is 6.928426742553711 Diffusion loss 2 is 7.198739528656006 Diffusion loss 3 is 7.308506965637207
20%|█▉ | 196/1000 [01:36<06:32, 2.05it/s]
Diffusion loss 0 is 6.331992149353027 Diffusion loss 1 is 6.926973342895508 Diffusion loss 2 is 7.197351932525635 Diffusion loss 3 is 7.307137489318848
20%|█▉ | 197/1000 [01:36<06:22, 2.10it/s]
Diffusion loss 0 is 6.330036640167236 Diffusion loss 1 is 6.925209999084473 Diffusion loss 2 is 7.195651531219482 Diffusion loss 3 is 7.305456161499023
20%|█▉ | 198/1000 [01:37<06:19, 2.11it/s]
Diffusion loss 0 is 6.328657150268555 Diffusion loss 1 is 6.924005508422852 Diffusion loss 2 is 7.194499492645264 Diffusion loss 3 is 7.304322719573975
20%|█▉ | 199/1000 [01:37<06:22, 2.10it/s]
Diffusion loss 0 is 6.327692031860352 Diffusion loss 1 is 6.923242092132568 Diffusion loss 2 is 7.1938042640686035 Diffusion loss 3 is 7.303646087646484
20%|██ | 200/1000 [01:38<06:17, 2.12it/s]
Diffusion loss 0 is 6.325066089630127 Diffusion loss 1 is 6.920807838439941 Diffusion loss 2 is 7.191437721252441 Diffusion loss 3 is 7.301299095153809
Diffusion loss 0 is 6.323544979095459 Diffusion loss 1 is 6.919504642486572 Diffusion loss 2 is 7.190210819244385 Diffusion loss 3 is 7.300090312957764
EPOCH 200. Loss 7.300090312957764. Flow strength 5.0. Heatmap of P embedding is
20%|██ | 202/1000 [01:39<07:17, 1.83it/s]
Diffusion loss 0 is 6.322741985321045 Diffusion loss 1 is 6.918931007385254 Diffusion loss 2 is 7.189703941345215 Diffusion loss 3 is 7.299598217010498
20%|██ | 203/1000 [01:40<07:05, 1.87it/s]
Diffusion loss 0 is 6.320585250854492 Diffusion loss 1 is 6.917019844055176 Diffusion loss 2 is 7.187854290008545 Diffusion loss 3 is 7.297760963439941
20%|██ | 204/1000 [01:40<06:59, 1.90it/s]
Diffusion loss 0 is 6.319456577301025 Diffusion loss 1 is 6.916112899780273 Diffusion loss 2 is 7.187027454376221 Diffusion loss 3 is 7.296954154968262
20%|██ | 205/1000 [01:41<06:50, 1.94it/s]
Diffusion loss 0 is 6.317010402679443 Diffusion loss 1 is 6.913835048675537 Diffusion loss 2 is 7.184808254241943 Diffusion loss 3 is 7.29475736618042
21%|██ | 206/1000 [01:41<06:52, 1.92it/s]
Diffusion loss 0 is 6.315019607543945 Diffusion loss 1 is 6.912021636962891 Diffusion loss 2 is 7.183060169219971 Diffusion loss 3 is 7.293031215667725
21%|██ | 207/1000 [01:42<06:54, 1.91it/s]
Diffusion loss 0 is 6.314259052276611 Diffusion loss 1 is 6.911486625671387 Diffusion loss 2 is 7.1825947761535645 Diffusion loss 3 is 7.2925848960876465
21%|██ | 208/1000 [01:42<06:47, 1.94it/s]
Diffusion loss 0 is 6.313827991485596 Diffusion loss 1 is 6.911313533782959 Diffusion loss 2 is 7.182487487792969 Diffusion loss 3 is 7.292491436004639
21%|██ | 209/1000 [01:43<06:45, 1.95it/s]
Diffusion loss 0 is 6.311731338500977 Diffusion loss 1 is 6.909470558166504 Diffusion loss 2 is 7.180724620819092 Diffusion loss 3 is 7.290747165679932
21%|██ | 210/1000 [01:43<06:36, 1.99it/s]
Diffusion loss 0 is 6.3102874755859375 Diffusion loss 1 is 6.9082722663879395 Diffusion loss 2 is 7.179598331451416 Diffusion loss 3 is 7.289639472961426
21%|██ | 211/1000 [01:44<06:36, 1.99it/s]
Diffusion loss 0 is 6.309597969055176 Diffusion loss 1 is 6.907873630523682 Diffusion loss 2 is 7.179278373718262 Diffusion loss 3 is 7.289333343505859
21%|██ | 212/1000 [01:44<06:32, 2.01it/s]
Diffusion loss 0 is 6.308292865753174 Diffusion loss 1 is 6.906872272491455 Diffusion loss 2 is 7.1783599853515625 Diffusion loss 3 is 7.288427352905273
21%|██▏ | 213/1000 [01:45<06:26, 2.04it/s]
Diffusion loss 0 is 6.306446552276611 Diffusion loss 1 is 6.9053425788879395 Diffusion loss 2 is 7.17691707611084 Diffusion loss 3 is 7.286996364593506
21%|██▏ | 214/1000 [01:45<06:22, 2.05it/s]
Diffusion loss 0 is 6.305623531341553 Diffusion loss 1 is 6.904825210571289 Diffusion loss 2 is 7.176483154296875 Diffusion loss 3 is 7.2865753173828125
22%|██▏ | 215/1000 [01:45<06:20, 2.06it/s]
Diffusion loss 0 is 6.303237438201904 Diffusion loss 1 is 6.902757167816162 Diffusion loss 2 is 7.1745076179504395 Diffusion loss 3 is 7.284612655639648
22%|██▏ | 216/1000 [01:46<06:20, 2.06it/s]
Diffusion loss 0 is 6.302562236785889 Diffusion loss 1 is 6.902388095855713 Diffusion loss 2 is 7.174225330352783 Diffusion loss 3 is 7.284344673156738
22%|██▏ | 217/1000 [01:46<06:18, 2.07it/s]
Diffusion loss 0 is 6.3008904457092285 Diffusion loss 1 is 6.901004791259766 Diffusion loss 2 is 7.17293119430542 Diffusion loss 3 is 7.283068656921387
22%|██▏ | 218/1000 [01:47<06:12, 2.10it/s]
Diffusion loss 0 is 6.300303936004639 Diffusion loss 1 is 6.900711536407471 Diffusion loss 2 is 7.172723293304443 Diffusion loss 3 is 7.2828779220581055
22%|██▏ | 219/1000 [01:47<06:17, 2.07it/s]
Diffusion loss 0 is 6.298389911651611 Diffusion loss 1 is 6.899105548858643 Diffusion loss 2 is 7.171204566955566 Diffusion loss 3 is 7.281374931335449
22%|██▏ | 220/1000 [01:48<06:07, 2.12it/s]
Diffusion loss 0 is 6.296774387359619 Diffusion loss 1 is 6.897794723510742 Diffusion loss 2 is 7.1699748039245605 Diffusion loss 3 is 7.2801594734191895
22%|██▏ | 221/1000 [01:48<06:05, 2.13it/s]
Diffusion loss 0 is 6.296182632446289 Diffusion loss 1 is 6.897543430328369 Diffusion loss 2 is 7.169819355010986 Diffusion loss 3 is 7.280017375946045
22%|██▏ | 222/1000 [01:49<06:03, 2.14it/s]
Diffusion loss 0 is 6.296102523803711 Diffusion loss 1 is 6.8978047370910645 Diffusion loss 2 is 7.17017126083374 Diffusion loss 3 is 7.2803802490234375
22%|██▏ | 223/1000 [01:49<06:06, 2.12it/s]
Diffusion loss 0 is 6.295217990875244 Diffusion loss 1 is 6.897256374359131 Diffusion loss 2 is 7.169706344604492 Diffusion loss 3 is 7.2799248695373535
22%|██▏ | 224/1000 [01:50<06:05, 2.12it/s]
Diffusion loss 0 is 6.294277667999268 Diffusion loss 1 is 6.896657943725586 Diffusion loss 2 is 7.169213771820068 Diffusion loss 3 is 7.279446125030518
22%|██▎ | 225/1000 [01:50<06:02, 2.14it/s]
Diffusion loss 0 is 6.293025493621826 Diffusion loss 1 is 6.895723819732666 Diffusion loss 2 is 7.168375015258789 Diffusion loss 3 is 7.278622627258301
23%|██▎ | 226/1000 [01:51<06:00, 2.15it/s]
Diffusion loss 0 is 6.292055130004883 Diffusion loss 1 is 6.895068168640137 Diffusion loss 2 is 7.16780424118042 Diffusion loss 3 is 7.278065204620361
23%|██▎ | 227/1000 [01:51<06:12, 2.07it/s]
Diffusion loss 0 is 6.290655136108398 Diffusion loss 1 is 6.893962860107422 Diffusion loss 2 is 7.166786193847656 Diffusion loss 3 is 7.277065277099609
23%|██▎ | 228/1000 [01:52<06:07, 2.10it/s]
Diffusion loss 0 is 6.289546012878418 Diffusion loss 1 is 6.893168926239014 Diffusion loss 2 is 7.166076183319092 Diffusion loss 3 is 7.276371479034424
23%|██▎ | 229/1000 [01:52<06:07, 2.10it/s]
Diffusion loss 0 is 6.288367748260498 Diffusion loss 1 is 6.892324447631836 Diffusion loss 2 is 7.165329933166504 Diffusion loss 3 is 7.275643348693848
23%|██▎ | 230/1000 [01:53<06:06, 2.10it/s]
Diffusion loss 0 is 6.2877092361450195 Diffusion loss 1 is 6.891982555389404 Diffusion loss 2 is 7.165079116821289 Diffusion loss 3 is 7.275410175323486
23%|██▎ | 231/1000 [01:53<06:05, 2.10it/s]
Diffusion loss 0 is 6.286772727966309 Diffusion loss 1 is 6.8913798332214355 Diffusion loss 2 is 7.164571762084961 Diffusion loss 3 is 7.2749176025390625
23%|██▎ | 232/1000 [01:54<06:05, 2.10it/s]
Diffusion loss 0 is 6.287400722503662 Diffusion loss 1 is 6.892341613769531 Diffusion loss 2 is 7.16562032699585 Diffusion loss 3 is 7.2759785652160645
23%|██▎ | 233/1000 [01:54<06:05, 2.10it/s]
Diffusion loss 0 is 6.286788463592529 Diffusion loss 1 is 6.892077922821045 Diffusion loss 2 is 7.165451526641846 Diffusion loss 3 is 7.275822639465332
23%|██▎ | 234/1000 [01:55<06:08, 2.08it/s]
Diffusion loss 0 is 6.286257266998291 Diffusion loss 1 is 6.891904354095459 Diffusion loss 2 is 7.165379524230957 Diffusion loss 3 is 7.275763034820557
24%|██▎ | 235/1000 [01:55<06:06, 2.09it/s]
Diffusion loss 0 is 6.285429000854492 Diffusion loss 1 is 6.891469478607178 Diffusion loss 2 is 7.165045261383057 Diffusion loss 3 is 7.2754364013671875
24%|██▎ | 236/1000 [01:55<06:05, 2.09it/s]
Diffusion loss 0 is 6.285285472869873 Diffusion loss 1 is 6.891699314117432 Diffusion loss 2 is 7.165380001068115 Diffusion loss 3 is 7.275783538818359
24%|██▎ | 237/1000 [01:56<06:02, 2.11it/s]
Diffusion loss 0 is 6.2843828201293945 Diffusion loss 1 is 6.891227722167969 Diffusion loss 2 is 7.165025234222412 Diffusion loss 3 is 7.275437355041504
24%|██▍ | 238/1000 [01:56<06:01, 2.11it/s]
Diffusion loss 0 is 6.283782005310059 Diffusion loss 1 is 6.891095161437988 Diffusion loss 2 is 7.164994239807129 Diffusion loss 3 is 7.275406837463379
24%|██▍ | 239/1000 [01:57<06:02, 2.10it/s]
Diffusion loss 0 is 6.282778739929199 Diffusion loss 1 is 6.890543460845947 Diffusion loss 2 is 7.164555072784424 Diffusion loss 3 is 7.274972915649414
24%|██▍ | 240/1000 [01:57<06:02, 2.10it/s]
Diffusion loss 0 is 6.28210973739624 Diffusion loss 1 is 6.8903093338012695 Diffusion loss 2 is 7.164427757263184 Diffusion loss 3 is 7.274852275848389
24%|██▍ | 241/1000 [01:58<06:00, 2.10it/s]
Diffusion loss 0 is 6.28136682510376 Diffusion loss 1 is 6.890007972717285 Diffusion loss 2 is 7.164236068725586 Diffusion loss 3 is 7.274672985076904
24%|██▍ | 242/1000 [01:58<05:55, 2.13it/s]
Diffusion loss 0 is 6.281215667724609 Diffusion loss 1 is 6.890303611755371 Diffusion loss 2 is 7.164648056030273 Diffusion loss 3 is 7.275094985961914
24%|██▍ | 243/1000 [01:59<05:58, 2.11it/s]
Diffusion loss 0 is 6.282698631286621 Diffusion loss 1 is 6.892261505126953 Diffusion loss 2 is 7.166714668273926 Diffusion loss 3 is 7.277167797088623
24%|██▍ | 244/1000 [01:59<05:56, 2.12it/s]
Diffusion loss 0 is 6.281570911407471 Diffusion loss 1 is 6.891645908355713 Diffusion loss 2 is 7.166227340698242 Diffusion loss 3 is 7.276682376861572
24%|██▍ | 245/1000 [02:00<05:55, 2.12it/s]
Diffusion loss 0 is 6.280299663543701 Diffusion loss 1 is 6.8909125328063965 Diffusion loss 2 is 7.165618419647217 Diffusion loss 3 is 7.276068687438965
25%|██▍ | 246/1000 [02:00<06:09, 2.04it/s]
Diffusion loss 0 is 6.2791829109191895 Diffusion loss 1 is 6.89030122756958 Diffusion loss 2 is 7.165118217468262 Diffusion loss 3 is 7.275563716888428
25%|██▍ | 247/1000 [02:01<06:06, 2.06it/s]
Diffusion loss 0 is 6.277928352355957 Diffusion loss 1 is 6.8895583152771 Diffusion loss 2 is 7.164485931396484 Diffusion loss 3 is 7.274930000305176
25%|██▍ | 248/1000 [02:01<06:03, 2.07it/s]
Diffusion loss 0 is 6.2770185470581055 Diffusion loss 1 is 6.889136791229248 Diffusion loss 2 is 7.164184093475342 Diffusion loss 3 is 7.2746405601501465
25%|██▍ | 249/1000 [02:02<06:02, 2.07it/s]
Diffusion loss 0 is 6.2767333984375 Diffusion loss 1 is 6.889347553253174 Diffusion loss 2 is 7.164516925811768 Diffusion loss 3 is 7.274985313415527
25%|██▌ | 250/1000 [02:02<06:00, 2.08it/s]
Diffusion loss 0 is 6.275850772857666 Diffusion loss 1 is 6.889028549194336 Diffusion loss 2 is 7.164322853088379 Diffusion loss 3 is 7.2747883796691895
25%|██▌ | 251/1000 [02:03<05:59, 2.08it/s]
Diffusion loss 0 is 6.274933338165283 Diffusion loss 1 is 6.888647079467773 Diffusion loss 2 is 7.164059638977051 Diffusion loss 3 is 7.274523735046387
25%|██▌ | 252/1000 [02:03<05:57, 2.09it/s]
Diffusion loss 0 is 6.273187637329102 Diffusion loss 1 is 6.887395858764648 Diffusion loss 2 is 7.162928104400635 Diffusion loss 3 is 7.2733917236328125
25%|██▌ | 253/1000 [02:04<05:59, 2.08it/s]
Diffusion loss 0 is 6.272842884063721 Diffusion loss 1 is 6.8876471519470215 Diffusion loss 2 is 7.163303375244141 Diffusion loss 3 is 7.273754119873047
25%|██▌ | 254/1000 [02:04<05:55, 2.10it/s]
Diffusion loss 0 is 6.273798942565918 Diffusion loss 1 is 6.889201641082764 Diffusion loss 2 is 7.164972305297852 Diffusion loss 3 is 7.275411128997803
26%|██▌ | 255/1000 [02:05<05:54, 2.10it/s]
Diffusion loss 0 is 6.272754669189453 Diffusion loss 1 is 6.888747215270996 Diffusion loss 2 is 7.164647579193115 Diffusion loss 3 is 7.275082588195801
26%|██▌ | 256/1000 [02:05<05:53, 2.11it/s]
Diffusion loss 0 is 6.272066593170166 Diffusion loss 1 is 6.888532638549805 Diffusion loss 2 is 7.16455602645874 Diffusion loss 3 is 7.275012969970703
26%|██▌ | 257/1000 [02:05<05:49, 2.13it/s]
Diffusion loss 0 is 6.273244857788086 Diffusion loss 1 is 6.890234470367432 Diffusion loss 2 is 7.166383743286133 Diffusion loss 3 is 7.276845932006836
26%|██▌ | 258/1000 [02:06<05:46, 2.14it/s]
Diffusion loss 0 is 6.272656440734863 Diffusion loss 1 is 6.890226364135742 Diffusion loss 2 is 7.166491508483887 Diffusion loss 3 is 7.276943206787109
26%|██▌ | 259/1000 [02:06<05:55, 2.09it/s]
Diffusion loss 0 is 6.274811744689941 Diffusion loss 1 is 6.892965316772461 Diffusion loss 2 is 7.16934871673584 Diffusion loss 3 is 7.279793739318848
26%|██▌ | 260/1000 [02:07<05:51, 2.11it/s]
Diffusion loss 0 is 6.273268222808838 Diffusion loss 1 is 6.892002582550049 Diffusion loss 2 is 7.16849946975708 Diffusion loss 3 is 7.278934478759766
26%|██▌ | 261/1000 [02:07<05:49, 2.11it/s]
Diffusion loss 0 is 6.270754337310791 Diffusion loss 1 is 6.890054225921631 Diffusion loss 2 is 7.16666841506958 Diffusion loss 3 is 7.27709436416626
26%|██▌ | 262/1000 [02:08<05:53, 2.09it/s]
Diffusion loss 0 is 6.271125793457031 Diffusion loss 1 is 6.890951633453369 Diffusion loss 2 is 7.1676926612854 Diffusion loss 3 is 7.2781219482421875
26%|██▋ | 263/1000 [02:08<05:52, 2.09it/s]
Diffusion loss 0 is 6.269778728485107 Diffusion loss 1 is 6.890154838562012 Diffusion loss 2 is 7.167013645172119 Diffusion loss 3 is 7.277441501617432
26%|██▋ | 264/1000 [02:09<05:53, 2.08it/s]
Diffusion loss 0 is 6.265632629394531 Diffusion loss 1 is 6.886617660522461 Diffusion loss 2 is 7.16359281539917 Diffusion loss 3 is 7.273998737335205
26%|██▋ | 265/1000 [02:09<05:53, 2.08it/s]
Diffusion loss 0 is 6.266018390655518 Diffusion loss 1 is 6.887603282928467 Diffusion loss 2 is 7.164683818817139 Diffusion loss 3 is 7.275060653686523
27%|██▋ | 266/1000 [02:10<05:52, 2.08it/s]
Diffusion loss 0 is 6.265402793884277 Diffusion loss 1 is 6.8875298500061035 Diffusion loss 2 is 7.164727687835693 Diffusion loss 3 is 7.275105953216553
27%|██▋ | 267/1000 [02:10<05:54, 2.07it/s]
Diffusion loss 0 is 6.264475345611572 Diffusion loss 1 is 6.887038230895996 Diffusion loss 2 is 7.164344787597656 Diffusion loss 3 is 7.274714469909668
27%|██▋ | 268/1000 [02:11<05:44, 2.12it/s]
Diffusion loss 0 is 6.261383056640625 Diffusion loss 1 is 6.883955478668213 Diffusion loss 2 is 7.161334991455078 Diffusion loss 3 is 7.271647930145264
27%|██▋ | 269/1000 [02:11<05:45, 2.12it/s]
Diffusion loss 0 is 6.261701583862305 Diffusion loss 1 is 6.883580684661865 Diffusion loss 2 is 7.161041736602783 Diffusion loss 3 is 7.2713751792907715
27%|██▋ | 270/1000 [02:12<05:47, 2.10it/s]
Diffusion loss 0 is 6.267710208892822 Diffusion loss 1 is 6.88992977142334 Diffusion loss 2 is 7.167638778686523 Diffusion loss 3 is 7.278119087219238
27%|██▋ | 271/1000 [02:12<05:45, 2.11it/s]
Diffusion loss 0 is 6.25568151473999 Diffusion loss 1 is 6.878565788269043 Diffusion loss 2 is 7.156843662261963 Diffusion loss 3 is 7.267836570739746
27%|██▋ | 272/1000 [02:13<05:44, 2.11it/s]
Diffusion loss 0 is 6.244182109832764 Diffusion loss 1 is 6.867252349853516 Diffusion loss 2 is 7.146116256713867 Diffusion loss 3 is 7.257583141326904
27%|██▋ | 273/1000 [02:13<05:42, 2.12it/s]
Diffusion loss 0 is 6.228621959686279 Diffusion loss 1 is 6.851802825927734 Diffusion loss 2 is 7.131226539611816 Diffusion loss 3 is 7.243093967437744
27%|██▋ | 274/1000 [02:14<05:43, 2.11it/s]
Diffusion loss 0 is 6.214014053344727 Diffusion loss 1 is 6.837202548980713 Diffusion loss 2 is 7.117147445678711 Diffusion loss 3 is 7.229361534118652
28%|██▊ | 275/1000 [02:14<05:41, 2.12it/s]
Diffusion loss 0 is 6.20033597946167 Diffusion loss 1 is 6.823360443115234 Diffusion loss 2 is 7.103989601135254 Diffusion loss 3 is 7.216666221618652
28%|██▊ | 276/1000 [02:15<05:42, 2.11it/s]
Diffusion loss 0 is 6.1912055015563965 Diffusion loss 1 is 6.814488410949707 Diffusion loss 2 is 7.095688819885254 Diffusion loss 3 is 7.2086567878723145
28%|██▊ | 277/1000 [02:15<05:39, 2.13it/s]
Diffusion loss 0 is 6.17034387588501 Diffusion loss 1 is 6.793891906738281 Diffusion loss 2 is 7.075628280639648 Diffusion loss 3 is 7.188848495483398
28%|██▊ | 278/1000 [02:15<05:42, 2.11it/s]
Diffusion loss 0 is 6.15336799621582 Diffusion loss 1 is 6.777338027954102 Diffusion loss 2 is 7.059620380401611 Diffusion loss 3 is 7.173064231872559
28%|██▊ | 279/1000 [02:16<05:51, 2.05it/s]
Diffusion loss 0 is 6.136073589324951 Diffusion loss 1 is 6.760651111602783 Diffusion loss 2 is 7.043425559997559 Diffusion loss 3 is 7.157032012939453
28%|██▊ | 280/1000 [02:16<05:48, 2.07it/s]
Diffusion loss 0 is 6.1205244064331055 Diffusion loss 1 is 6.745795726776123 Diffusion loss 2 is 7.029045581817627 Diffusion loss 3 is 7.142786979675293
28%|██▊ | 281/1000 [02:17<05:51, 2.04it/s]
Diffusion loss 0 is 6.107326030731201 Diffusion loss 1 is 6.733419418334961 Diffusion loss 2 is 7.017035007476807 Diffusion loss 3 is 7.1308274269104
28%|██▊ | 282/1000 [02:17<05:49, 2.06it/s]
Diffusion loss 0 is 6.094951152801514 Diffusion loss 1 is 6.721971035003662 Diffusion loss 2 is 7.005856513977051 Diffusion loss 3 is 7.11964225769043
28%|██▊ | 283/1000 [02:18<05:52, 2.04it/s]
Diffusion loss 0 is 6.081467151641846 Diffusion loss 1 is 6.709475994110107 Diffusion loss 2 is 6.993545055389404 Diffusion loss 3 is 7.107264518737793
28%|██▊ | 284/1000 [02:18<05:54, 2.02it/s]
Diffusion loss 0 is 6.070034503936768 Diffusion loss 1 is 6.699057102203369 Diffusion loss 2 is 6.983234882354736 Diffusion loss 3 is 7.09684944152832
28%|██▊ | 285/1000 [02:19<05:51, 2.03it/s]
Diffusion loss 0 is 6.05271577835083 Diffusion loss 1 is 6.682714939117432 Diffusion loss 2 is 6.967071056365967 Diffusion loss 3 is 7.080641269683838
29%|██▊ | 286/1000 [02:19<05:43, 2.08it/s]
Diffusion loss 0 is 6.037919998168945 Diffusion loss 1 is 6.668928623199463 Diffusion loss 2 is 6.953463554382324 Diffusion loss 3 is 7.0669941902160645
29%|██▊ | 287/1000 [02:20<05:52, 2.03it/s]
Diffusion loss 0 is 6.025206089019775 Diffusion loss 1 is 6.657227039337158 Diffusion loss 2 is 6.94187593460083 Diffusion loss 3 is 7.055328845977783
29%|██▉ | 288/1000 [02:20<05:48, 2.04it/s]
Diffusion loss 0 is 6.010578155517578 Diffusion loss 1 is 6.643560409545898 Diffusion loss 2 is 6.928254127502441 Diffusion loss 3 is 7.041598320007324
29%|██▉ | 289/1000 [02:21<05:53, 2.01it/s]
Diffusion loss 0 is 5.99567985534668 Diffusion loss 1 is 6.6296000480651855 Diffusion loss 2 is 6.914327144622803 Diffusion loss 3 is 7.027562618255615
29%|██▉ | 290/1000 [02:21<05:56, 1.99it/s]
Diffusion loss 0 is 5.980838775634766 Diffusion loss 1 is 6.615710258483887 Diffusion loss 2 is 6.900392532348633 Diffusion loss 3 is 7.013477325439453
29%|██▉ | 291/1000 [02:22<05:52, 2.01it/s]
Diffusion loss 0 is 5.965755462646484 Diffusion loss 1 is 6.601563453674316 Diffusion loss 2 is 6.886212348937988 Diffusion loss 3 is 6.999160289764404
29%|██▉ | 292/1000 [02:22<05:46, 2.04it/s]
Diffusion loss 0 is 5.9512248039245605 Diffusion loss 1 is 6.587932586669922 Diffusion loss 2 is 6.872469902038574 Diffusion loss 3 is 6.985241413116455
29%|██▉ | 293/1000 [02:23<05:42, 2.06it/s]
Diffusion loss 0 is 5.9371657371521 Diffusion loss 1 is 6.574778079986572 Diffusion loss 2 is 6.859223365783691 Diffusion loss 3 is 6.971837043762207
29%|██▉ | 294/1000 [02:23<05:38, 2.08it/s]
Diffusion loss 0 is 5.921340465545654 Diffusion loss 1 is 6.559799671173096 Diffusion loss 2 is 6.844103813171387 Diffusion loss 3 is 6.956538677215576
30%|██▉ | 295/1000 [02:24<05:38, 2.08it/s]
Diffusion loss 0 is 5.905215740203857 Diffusion loss 1 is 6.5445098876953125 Diffusion loss 2 is 6.8287034034729 Diffusion loss 3 is 6.940987586975098
30%|██▉ | 296/1000 [02:24<05:36, 2.09it/s]
Diffusion loss 0 is 5.891111373901367 Diffusion loss 1 is 6.531189918518066 Diffusion loss 2 is 6.815183162689209 Diffusion loss 3 is 6.92726993560791
30%|██▉ | 297/1000 [02:25<05:35, 2.10it/s]
Diffusion loss 0 is 5.878739356994629 Diffusion loss 1 is 6.519484519958496 Diffusion loss 2 is 6.803241729736328 Diffusion loss 3 is 6.9151225090026855
30%|██▉ | 298/1000 [02:25<05:38, 2.07it/s]
Diffusion loss 0 is 5.863198757171631 Diffusion loss 1 is 6.50465726852417 Diffusion loss 2 is 6.788244724273682 Diffusion loss 3 is 6.899957656860352
30%|██▉ | 299/1000 [02:26<05:36, 2.09it/s]
Diffusion loss 0 is 5.8482346534729 Diffusion loss 1 is 6.490304946899414 Diffusion loss 2 is 6.773712158203125 Diffusion loss 3 is 6.885264873504639
30%|███ | 300/1000 [02:26<05:34, 2.09it/s]
Diffusion loss 0 is 5.833312511444092 Diffusion loss 1 is 6.475955486297607 Diffusion loss 2 is 6.7591423988342285 Diffusion loss 3 is 6.8705153465271
Diffusion loss 0 is 5.816686153411865 Diffusion loss 1 is 6.459866523742676 Diffusion loss 2 is 6.742828369140625 Diffusion loss 3 is 6.854021072387695
EPOCH 300. Loss 6.854021072387695. Flow strength 5.0. Heatmap of P embedding is
30%|███ | 302/1000 [02:27<06:18, 1.84it/s]
Diffusion loss 0 is 5.804393768310547 Diffusion loss 1 is 6.448009490966797 Diffusion loss 2 is 6.730683326721191 Diffusion loss 3 is 6.84166955947876
30%|███ | 302/1000 [02:28<05:42, 2.04it/s]
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/Users/adjourner/Projects/directed_graphs/05c Multiscale Diffusion Flow Embedding.ipynb Cell 15 in <cell line: 1>()
----> 1 MFE.fit(n_steps=1000)

/Users/adjourner/Projects/directed_graphs/05c Multiscale Diffusion Flow Embedding.ipynb Cell 15 in MultiscaleDiffusionFlowEmbedder.fit(self, n_steps)
    234 self.optim.zero_grad()
    235 # compute loss
--> 236 loss = self.loss()
    237 if loss.isnan():
    238 print("Final loss was nan")

/Users/adjourner/Projects/directed_graphs/05c Multiscale Diffusion Flow Embedding.ipynb Cell 15 in MultiscaleDiffusionFlowEmbedder.loss(self)
    116 # compute diffusion loss on embedded points
    117 if self.diffusion_loss != 0:
--> 118 diffusion_loss = self.diffusion_loss()
    119 else:
    120 diffusion_loss = 0

/Users/adjourner/Projects/directed_graphs/05c Multiscale Diffusion Flow Embedding.ipynb Cell 15 in MultiscaleDiffusionFlowEmbedder.diffusion_loss(self)
     88 # normalize embedded points to lie within -self.embedding_bounds, self.embedding_bounds
     89 # if any are trying to escape, constrain them to lie on the edges
     90 # self.embedded_points[:,0][torch.abs(self.embedded_points[:,0]) > self.embedding_bounds] = self.embedding_bounds * (self.embedded_points[:,0][torch.abs(self.embedded_points[:,0]) > self.embedding_bounds])/torch.abs(self.embedded_points[:,0][torch.abs(self.embedded_points[:,0]) > self.embedding_bounds])
     91 # self.embedded_points[:,1][torch.abs(self.embedded_points[:,1]) > self.embedding_bounds] = self.embedding_bounds * (self.embedded_points[:,1][torch.abs(self.embedded_points[:,1]) > self.embedding_bounds])/torch.abs(self.embedded_points[:,0][torch.abs(self.embedded_points[:,1]) > self.embedding_bounds])
     92 # compute embedding diffusion matrix, using including diffusion to grid points
     93 for i,t in enumerate(self.ts):
---> 94 self.P_embedding_ts[i] = diffusion_matrix_with_grid_points(X = self.embedded_points, grid=self.grid, flow_function = self.FlowArtist, t = t, sigma = self.sigma_embedding, flow_strength=self.flow_strength)
     95 # set any affinities of zero to a very small amount, to prevent the KL divergence from becoming infinite.
     96 # self.P_embedding_ts[i][self.P_embedding_ts[i] == 0] = self.epsilon # TODO: Perhaps enable later; this didn't cause NaN errors before
     97 self.P_embedding_ts[i] = self.P_embedding_ts[i] + self.epsilon

/Users/adjourner/Projects/directed_graphs/05c Multiscale Diffusion Flow Embedding.ipynb Cell 15 in diffusion_matrix_with_grid_points(X, grid, flow_function, t, sigma, flow_strength)
      9 flow_per_point = flow_function(points_and_grid)
     10 # take a diffusion matrix
---> 11 A = affinity_matrix_from_pointset_to_pointset(points_and_grid,points_and_grid, flows = flow_per_point, sigma = sigma, flow_strength=flow_strength)
     12 P = F.normalize(A, p=1, dim=-1)
     13 # TODO: Should we remove self affinities? Probably not, as lazy random walks are advantageous when powering
     14 # Power the matrix to t steps

File ~/Projects/directed_graphs/directed_graphs/diffusion_flow_embedding.py:82, in affinity_matrix_from_pointset_to_pointset(pointset1, pointset2, flows, n_neighbors, sigma, flow_strength)
     80 P3 = P3.transpose(1,2)
     81 # compute affinities from flows and directions
---> 82 affinities = affinity_from_flow(flows,P3,sigma=sigma,flow_strength=flow_strength)
     83 return affinities

File ~/Projects/directed_graphs/directed_graphs/diffusion_flow_embedding.py:31, in affinity_from_flow(flows, directions_array, flow_strength, sigma)
     29 n_directions = directions_array.shape[0]
     30 # Normalize directions
---> 31 length_of_directions = torch.linalg.norm(directions_array,dim=-1)
     32 normed_directions = F.normalize(directions_array,dim=-1)
     33 # and normalize flows # TODO: Perhaps reconsider
     34 # Calculate flow lengths, used to scale directions to flow
     35 # flow_lengths = torch.linalg.norm(flows,dim=-1)

KeyboardInterrupt:
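The comments in the traceback describe the core of the diffusion step: L1-normalize the rows of the affinity matrix, power the result to t steps, and add a small epsilon so that a KL divergence over the entries stays finite. Here is a minimal sketch of that sequence, using NumPy rather than torch to keep it framework-free. The helper names `diffusion_operator` and `kl_divergence`, the toy affinity matrix, and the epsilon value are assumptions made for illustration; the notebook's own code additionally mixes in grid points and a learned flow field, which this sketch omits.

```python
import numpy as np

def diffusion_operator(A, t, epsilon=1e-8):
    """Hypothetical helper: row-normalize affinities, power to t steps, add epsilon."""
    A = np.asarray(A, dtype=float)
    # L1 row normalization turns non-negative affinities into transition probabilities
    P = A / A.sum(axis=1, keepdims=True)
    # t-step random walk via ordinary matrix powering
    P_t = np.linalg.matrix_power(P, t)
    # shift every entry by epsilon so no entry is exactly zero inside a logarithm
    return P_t + epsilon

def kl_divergence(P, Q):
    """Sum over all entries of P * log(P / Q); both inputs must be strictly positive."""
    return float(np.sum(P * np.log(P / Q)))

# Toy affinity matrix (assumed values); self-affinities are kept, giving a lazy walk
A = np.array([[1.0, 0.5, 0.0],
              [0.5, 1.0, 0.5],
              [0.0, 0.5, 1.0]])

P1 = diffusion_operator(A, t=1)
P4 = diffusion_operator(A, t=4)
print(kl_divergence(P1, P4))
```

Without the epsilon shift, the zero entries of the one-step matrix would yield `0 * log(0)` terms, which evaluate to NaN in floating point.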
MFE.visualize_points(labels=labels)
No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
plt.plot(MFE.losses["diffusion"], label = "diffusion")
plt.plot(MFE.losses["smoothness"], label = 'smoothness')
plt.plot(MFE.losses["reconstruction"], label = "reconstruction")
plt.legend()
<matplotlib.legend.Legend at 0x2b1bf09dfee0>
X, flow, labels = directed_swiss_roll_uniform(num_nodes=1000, num_spirals=2.5, radius=10, height=3, xtilt=0, ytilt=0)
plot_directed_3d(X, flow, labels, mask_prob=0.5)
X = torch.tensor(X)
flow = torch.tensor(flow)
X = X.float().to(device)
flow = flow.float().to(device)
MFE = MultiscaleDiffusionFlowEmbedder(X, flow, device=device).to(device)
MFE.fit(n_steps=10000)
0%| | 0/10000 [00:00<?, ?it/s]
EPOCH 0. Loss 11.891934394836426. Flow strength 4.999000072479248. Heatmap of P embedding is
1%| | 100/10000 [00:15<19:12, 8.59it/s]
EPOCH 100. Loss 8.272184371948242. Flow strength 4.952264785766602. Heatmap of P embedding is
2%|▏ | 200/10000 [00:30<26:09, 6.24it/s]
EPOCH 200. Loss 8.326654434204102. Flow strength 4.955779552459717. Heatmap of P embedding is
3%|▎ | 300/10000 [00:44<22:10, 7.29it/s]
EPOCH 300. Loss 8.010849952697754. Flow strength 4.956110954284668. Heatmap of P embedding is
4%|▍ | 400/10000 [00:59<22:36, 7.08it/s]
EPOCH 400. Loss 7.801275730133057. Flow strength 4.9596638679504395. Heatmap of P embedding is
5%|▌ | 500/10000 [01:14<21:53, 7.23it/s]
EPOCH 500. Loss 7.280424118041992. Flow strength 4.979403018951416. Heatmap of P embedding is
6%|▌ | 600/10000 [01:29<22:26, 6.98it/s]
EPOCH 600. Loss 6.438525199890137. Flow strength 4.982205867767334. Heatmap of P embedding is
7%|▋ | 700/10000 [01:44<23:43, 6.53it/s]
EPOCH 700. Loss 6.5063371658325195. Flow strength 4.990403175354004. Heatmap of P embedding is
8%|▊ | 800/10000 [01:58<17:17, 8.87it/s]
EPOCH 800. Loss 6.906248092651367. Flow strength 4.994604110717773. Heatmap of P embedding is
9%|▉ | 900/10000 [02:12<24:16, 6.25it/s]
EPOCH 900. Loss 6.449456214904785. Flow strength 4.994830131530762. Heatmap of P embedding is
10%|█ | 1000/10000 [02:26<20:38, 7.27it/s]
EPOCH 1000. Loss 6.436872482299805. Flow strength 4.991913318634033. Heatmap of P embedding is
11%|█ | 1100/10000 [02:39<16:41, 8.88it/s]
EPOCH 1100. Loss 6.227049827575684. Flow strength 4.983253002166748. Heatmap of P embedding is
12%|█▏ | 1200/10000 [02:51<20:18, 7.22it/s]
EPOCH 1200. Loss 6.149223327636719. Flow strength 4.982929706573486. Heatmap of P embedding is
13%|█▎ | 1300/10000 [03:06<19:53, 7.29it/s]
EPOCH 1300. Loss 5.984847068786621. Flow strength 4.990385055541992. Heatmap of P embedding is
14%|█▍ | 1400/10000 [03:20<19:42, 7.27it/s]
EPOCH 1400. Loss 5.991268634796143. Flow strength 4.995453357696533. Heatmap of P embedding is
15%|█▌ | 1500/10000 [03:35<19:28, 7.28it/s]
EPOCH 1500. Loss 5.852634906768799. Flow strength 5.002211093902588. Heatmap of P embedding is
16%|█▌ | 1600/10000 [03:50<19:15, 7.27it/s]
EPOCH 1600. Loss 5.773377895355225. Flow strength 5.013149738311768. Heatmap of P embedding is
17%|█▋ | 1700/10000 [04:04<18:34, 7.45it/s]
EPOCH 1700. Loss 5.805696964263916. Flow strength 5.04870080947876. Heatmap of P embedding is
18%|█▊ | 1800/10000 [04:19<19:17, 7.09it/s]
EPOCH 1800. Loss 5.712757110595703. Flow strength 5.038760662078857. Heatmap of P embedding is
19%|█▉ | 1900/10000 [04:33<18:37, 7.25it/s]
EPOCH 1900. Loss 5.071902751922607. Flow strength 5.019508361816406. Heatmap of P embedding is
20%|██ | 2000/10000 [04:47<17:01, 7.84it/s]
EPOCH 2000. Loss 4.494369983673096. Flow strength 5.034949779510498. Heatmap of P embedding is
21%|██ | 2100/10000 [05:01<21:47, 6.04it/s]
EPOCH 2100. Loss 3.9733169078826904. Flow strength 5.047738075256348. Heatmap of P embedding is
21%|██ | 2122/10000 [05:06<18:22, 7.15it/s]
!nbdev_build_lib
Converted 00_core.ipynb. Converted 01a Graph Utils.ipynb. Converted 01b01 Directed Graph Utils.ipynb. Converted 01c Visualizations.ipynb. Converted 01d Batch Testing Scripts.ipynb. Converted 02a Toy Graph Datasets.ipynb. Converted 02a01 Communities Datasets.ipynb. Converted 02a02 Small Random Directed Graphs.ipynb. Converted 02b Directed Stochastic Block Model.ipynb. Converted 02b01 Sink Graph with SBM.ipynb. Converted 02b02 SBM Source Graph.ipynb. Converted 02b03 Directed_SBM_On_Manifold.ipynb. Converted 02c Trees.ipynb. Converted 02d01_Snake_TestCase.ipynb. Converted 02d01_Snake_With5b (1).ipynb. Converted 02d02 Circles and Swiss roll Datasets.ipynb. Converted 02d02a Testing Diffusion Flow Embedder on Circle and Swiss Roll.ipynb. Converted 02d03 Spheres and Donuts with Flow (datasets).ipynb. Converted 02d04 Tree and Flow Artist Test.ipynb. Converted 02d05 Branch and Clusters.ipynb. Converted 02d05a Testing Branch Clusters Circle Swiss roll.ipynb. Converted 03 Node2Vec on Directed Graphs.ipynb. Converted 03a Node2Vec_with_Backwards_Connection.ipynb. Converted 03a01 DeepWalk_with_Backwards_Connections.ipynb. Converted 03b Backwards Node2Vec.ipynb. Converted 03b01 Testing Backwards Node2vec on DirectedGraphs.ipynb. Converted 03c node2vec graph reversal walk.ipynb. Converted 03c01 Node2Vec_on_RootedTree.ipynb. Converted 03c01a Node2Vec_on_MultiTree.ipynb. Converted 03c01b Node2Vec_on_PolyTree.ipynb. Converted 03c01c Node2Vec_on_Cyclic_Graphs.ipynb. Converted 04 UW_Manifold_Based_Directed_Embedding.ipynb. Converted 05a Testing Distance-based Flow Embedder.ipynb. Converted 05a01 Flow_Embedding_with_Implicit_Vectors.ipynb. Converted 05a02 Model Spaces for Flow Embeddings.ipynb. Converted 05b Flow Embedding with Diffusion.ipynb. Converted 05b01 Testing diffusion flow embedder.ipynb. Converted 05b01a Diffusion Flow Embedding on Pancreas_RNA_Velocity.ipynb. Converted 05b01a01 RNA Velocity Analysis.ipynb. Converted 05b01a02 rna-velo_viz_phate.ipynb. 
Converted 05b01a03 More_RNA_Velocity_Datasets.ipynb. Converted 05b02 Flow-Artist_Upgrade.ipynb. Converted 05b03 Flow-Artist-Scratch-Upgrade.ipynb. Converted 05b03 Testing the Flow-Affinity Matrix.ipynb. Converted 05b04 Diffusion Flow Embedder on Toy Data.ipynb. Converted 05b05 Test Swiss Roll with New Kernel (7-11).ipynb. Converted 05b05a DFE on the Swiss Roll with Smoothness.ipynb. Converted 05c Multiscale Diffusion Flow Embedding.ipynb. Converted 05c01 Loss functions.ipynb. Converted 05c02 Testing MFE on Swiss Roll.ipynb. Converted 05c03 Testing MFE on RNA Velocity Dataset.ipynb. Converted 99_playground.ipynb. Converted Untitled.ipynb. Converted index.ipynb.
from directed_graphs.datasets import plot_ribbon, sample_ribbon, plot_ribbon_samples
from directed_graphs.datasets import affinity_grid_search
Canceled future for execute_request message before replies were done
The Kernel crashed while executing code in the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure.
plot_ribbon()
X, X_ = sample_ribbon()
labels = X
plot_ribbon_samples(X, X_)
X = torch.tensor(X)
X_ = torch.tensor(X_)
X = X.float().to(device)
X_ = X_.float().to(device)
affinity_grid_search(X, X_, sigmas=[0.5, 1, 3, 5], flow_strengths=[1, 2, 5, 10])
dfe = DiffusionFlowEmbedder(
    X,
    X_,
    t=1,
    sigma_graph=15,
    sigma_embedding=15,
    device=device,
    autoencoder_shape=[50, 10],
    flow_artist_shape=[30, 20, 10],
    flow_strength_graph=2,
    flow_strength_embedding=5,
    learnable_flow_strength=True,
    weight_of_flow=0.5,
    learning_rate=0.001,
)
dfe = dfe.to(device)
embeddings = dfe.fit(n_steps=2000)
dfe.visualize_points(labels)